v4.0 update. (#2371)
@@ -42,7 +42,6 @@
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/arch/grid_dependency_control.h"

@@ -288,7 +287,7 @@ struct CollectiveMma<
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;

constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
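Context for the hunk above: TMA requires the contiguous extent of each operand to be aligned to 128 bits, so the minimum element count per operand is 128 / sizeof_bits(Element). A standalone sketch of that arithmetic (the element widths and values below are illustrative assumptions, not taken from the diff):

#include <cstdio>

// Minimal sketch of the TMA alignment rule checked above: the contiguous
// extent must be a multiple of 128 bits, i.e. of 128 / sizeof_bits(Element)
// elements.
template <int SizeofBits>
constexpr int min_tma_aligned_elements() {
  return 128 / SizeofBits;
}

int main() {
  int K = 4096;
  // 16-bit operands need K % 8 == 0; 8-bit operands need K % 16 == 0.
  int ok_f16 = (K % min_tma_aligned_elements<16>() == 0);
  int ok_f8  = (K % min_tma_aligned_elements<8>()  == 0);
  std::printf("f16 aligned: %d, f8 aligned: %d\n", ok_f16, ok_f8);
}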
@@ -445,7 +444,7 @@ struct CollectiveMma<
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
++k_tile_iter;

if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
  launch_dep_grids = true;
  cutlass::arch::launch_dependent_grids();
}

@@ -453,7 +452,7 @@ struct CollectiveMma<
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
  cutlass::arch::launch_dependent_grids();
}
}

@@ -533,7 +532,7 @@ struct CollectiveMma<
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
++k_tile_iter;

if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
  launch_dep_grids = true;
  cutlass::arch::launch_dependent_grids();
}

@@ -541,7 +540,7 @@ struct CollectiveMma<
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
  cutlass::arch::launch_dependent_grids();
}
}
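The four hunks above keep the grid-dependency-control (GDC) trigger inside the producer loops. cutlass::arch::launch_dependent_grids() signals that the dependent grids of a programmatic dependent launch may start; it should fire once per block, hence the launch_dep_grids flag. A hedged sketch of the producer-side pattern (the surrounding TMA loop and counter bookkeeping are simplified away):

#include "cutlass/arch/grid_dependency_control.h"

// Sketch only: fire the dependent-grid launch once enough tiles have been
// issued, and never more than once per block. disable_gdc opts out when the
// kernel was not launched with programmatic dependent launch (PDL) enabled.
__device__ inline void maybe_launch_dependents(
    bool disable_gdc, int cnt, int launch_dep_grids_threshold, bool& launch_dep_grids) {
  if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
    launch_dep_grids = true;
    cutlass::arch::launch_dependent_grids();
  }
}

The second if-block in each pair is the fallback after the loop: if the threshold was never reached, the dependent grids are still launched before the producer finishes.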
@@ -634,9 +633,9 @@ struct CollectiveMma<
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
 * Waits for all stages to either be released (all
 * Consumer UNLOCKs), or if the stage was never used
 * then would just be acquired since the phase was
 * still inverted from make_producer_start_state
 */
pipeline.producer_tail(smem_pipe_write);
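A short gloss on the comment above, as a minimal sketch of the producer epilogue (assuming a cutlass::PipelineTmaAsync-style pipeline object):

// Sketch: before a producer exits, drain the pipeline so no CTA in the
// cluster frees shared memory a peer may still be reading. Stages that were
// never used are simply acquired, since their phase still carries the
// producer start state from make_producer_start_state.
if (lane_predicate) {
  pipeline.producer_tail(smem_pipe_write);
}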
@@ -854,7 +853,7 @@ struct CollectiveMma<
k_tile_count -= prologue_mma_count;

smem_pipe_release.advance(k_tile_count);

// Wait on all GMMAs to complete
warpgroup_wait<0>();

@@ -133,7 +133,7 @@ using TP = _8;
static constexpr int TP_ = TP{};

#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
-(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
+(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

// Distributed GEMM tiling/sharding schedule
// Choices:

@@ -252,7 +252,8 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];

-#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
+//        (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types

@@ -345,7 +346,7 @@ struct Result {
};

#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
-(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
+(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation

@@ -803,17 +804,18 @@ int run(Options &options) {
return 0;
}

-#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
+//        (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

-// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
+// CUTLASS must be compiled with CUDA Toolkit 12.6 or newer to run this example
// and must have compute capability at least 90.
-// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
-if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
-std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
+// Some necessary cuda graph APIs were only introduced in CUDA 12.6.
+if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) {
+std::cerr << "This example requires CUDA 12.6 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}

@@ -857,11 +859,11 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//

-#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
run(options);
#else
std::cerr
-<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.4 or later." << std::endl;
+<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.6 or later." << std::endl;
return 0;
#endif
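The preprocessor guard rewritten above is a genuine bug fix, not just a version bump: the old form (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4) rejects any future major release whose minor version resets (CUDA 13.0 has minor 0), while the new form is the usual lexicographic (major, minor) comparison. As a reusable sketch (the macro name is an assumption for illustration):

// Hypothetical helper expressing "toolkit >= MAJOR.MINOR" the same way the
// corrected guard does. 13.0 passes for (12, 6); 12.5 does not.
#define CUDACC_AT_LEAST(MAJOR, MINOR) \
  (__CUDACC_VER_MAJOR__ > (MAJOR) || \
   (__CUDACC_VER_MAJOR__ == (MAJOR) && __CUDACC_VER_MINOR__ >= (MINOR)))

#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && CUDACC_AT_LEAST(12, 6)
// ... Distributed GEMM path ...
#endif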
@@ -250,8 +250,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

-using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
-
/// Result structure
struct Result
{

@@ -518,7 +516,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}

-arguments.scheduler.raster_order = options.raster;
+arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
arguments.scheduler.max_swizzle_size = options.swizzle;

@@ -690,10 +688,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)

std::string raster = "Heuristic";

-if (options.raster == RasterOrderOptions::AlongN) {
+if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
-else if (options.raster == RasterOrderOptions::AlongM) {
+else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}

@@ -747,7 +745,7 @@ int main(int argc, char const **args) {
// Parse options
//

-Options<RasterOrderOptions, ProblemShape> options;
+Options<ProblemShape> options;

options.parse(argc, args);

@@ -253,8 +253,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

-using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
-
/// Result structure
struct Result
{

@@ -523,7 +521,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}

-arguments.scheduler.raster_order = options.raster;
+arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
arguments.scheduler.max_swizzle_size = options.swizzle;

@@ -699,10 +697,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)

std::string raster = "Heuristic";

-if (options.raster == RasterOrderOptions::AlongN) {
+if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
-else if (options.raster == RasterOrderOptions::AlongM) {
+else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}

@@ -755,7 +753,7 @@ int main(int argc, char const **args) {
// Parse options
//

-Options<RasterOrderOptions, ProblemShape> options;
+Options<ProblemShape> options;

options.parse(argc, args);

@@ -30,10 +30,9 @@
**************************************************************************************************/

// Command line options parsing
-template<typename _RasterOrderOptions, typename _ProblemShape>
+using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
+template<typename _ProblemShape>
struct Options {

-using RasterOrderOptions = _RasterOrderOptions;
using ProblemShape = _ProblemShape;

bool help = false;

@@ -50,7 +49,7 @@ struct Options {
int const m_alignment = 128;
int const n_alignment = 128;

-RasterOrderOptions raster;
+RasterOrderOptions raster_order;
int swizzle;

// Parses the command line

@@ -74,13 +73,13 @@ struct Options {
cmd.get_cmd_line_argument("raster", raster_char);

if (raster_char == 'N' || raster_char == 'n') {
-raster = RasterOrderOptions::AlongN;
+raster_order = RasterOrderOptions::AlongN;
}
else if (raster_char == 'M' || raster_char == 'm') {
-raster = RasterOrderOptions::AlongM;
+raster_order = RasterOrderOptions::AlongM;
}
else if (raster_char == 'H' || raster_char == 'h') {
-raster = RasterOrderOptions::Heuristic;
+raster_order = RasterOrderOptions::Heuristic;
}

cmd.get_cmd_line_argument("swizzle", swizzle, 1);
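Taken together, these hunks replace the per-scheduler RasterOrderOptions typedef with the shared cutlass::gemm::kernel::detail::RasterOrderOptions enum and rename the Options member from raster to raster_order. A condensed sketch of the resulting flow (trimmed to the relevant members; not the full Options struct):

using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;

struct OptionsSketch {
  RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;

  // Maps the --raster=<N|M|H> command-line character onto the shared enum.
  void set_raster(char raster_char) {
    if      (raster_char == 'N' || raster_char == 'n') raster_order = RasterOrderOptions::AlongN;
    else if (raster_char == 'M' || raster_char == 'm') raster_order = RasterOrderOptions::AlongM;
    else if (raster_char == 'H' || raster_char == 'h') raster_order = RasterOrderOptions::Heuristic;
  }
};

The scheduler then consumes it directly, as in the diff: arguments.scheduler.raster_order = options.raster_order;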
@@ -543,7 +543,7 @@ int run(Options &options) {

int main(int argc, char const **args) {

-// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
+// CUTLASS must be compiled with CUDA 12.8 Toolkit or newer to run this example
// and must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;

@@ -560,7 +560,6 @@ int main(int argc, char const **args) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}

//
// Parse options
//

@@ -237,7 +237,7 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

-using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
+using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {

@@ -300,7 +300,7 @@ auto make_iterator(T* ptr) {
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

-using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
+using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {

@@ -490,7 +490,7 @@ int run(Options &options)

int main(int argc, char const **args) {

-// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
+// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;

@@ -503,11 +503,11 @@ int main(int argc, char const **args) {
-cudaError_t error = cudaGetDeviceProperties(&props, 0);
+CUDA_CHECK(cudaGetDevice(&current_device_id));
+CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));

if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}

}
//
// Parse options
//

@@ -490,7 +490,7 @@ int run(Options &options)

int main(int argc, char const **args) {

-// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
+// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;

@@ -503,11 +503,11 @@ int main(int argc, char const **args) {
-cudaError_t error = cudaGetDeviceProperties(&props, 0);
+CUDA_CHECK(cudaGetDevice(&current_device_id));
+CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));

if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}

}
//
// Parse options
//

@@ -499,11 +499,11 @@ int main(int argc, char const **args) {
-cudaError_t error = cudaGetDeviceProperties(&props, 0);
+CUDA_CHECK(cudaGetDevice(&current_device_id));
+CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));

if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}

//
// Parse options
//
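Two things change in the repeated guard above: device properties are now read for the current device instead of hard-coded device 0, and both runtime calls are wrapped in CUDA_CHECK. Note that the compute-capability test as written reduces to props.major != 10, since (props.minor != 0 || props.minor != 1) is true for every value of minor. A sketch of an equivalent positive check, assuming the intent is cc 100/101:

#include <cuda_runtime.h>

// Sketch: returns true iff the current device is Blackwell (cc 100 or 101).
// Error handling is collapsed to a boolean for brevity.
inline bool current_device_is_blackwell() {
  int device_id = 0;
  cudaDeviceProp props{};
  if (cudaGetDevice(&device_id) != cudaSuccess) return false;
  if (cudaGetDeviceProperties(&props, device_id) != cudaSuccess) return false;
  return props.major == 10 && (props.minor == 0 || props.minor == 1);
}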
@@ -117,15 +117,17 @@ struct Options {
int q = 256;
int k = 256;
int d = 128;
+int warmup_iterations = 1;
int iterations = 3;
+int tensor_ring_buffers = 1;
bool verify = false;
bool verbose = false;

bool causal = false;
bool residual = false;
bool varlen = false;
bool persistent = false;
int sm_count = 0;

std::string kernel_filter;

InitStyle init_style_q = InitStyle::kRandom;

@@ -189,10 +191,15 @@ struct Options {
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;

+cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
+cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);

verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
+persistent = cmd.check_cmd_line_flag("persistent");

std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "no" || mask == "") {

@@ -210,7 +217,6 @@ struct Options {
causal = false;
}
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);

get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);

@@ -235,10 +241,13 @@ struct Options {
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extent\n"
+<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
+<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
+<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
@@ -379,40 +388,55 @@ struct FwdRunner {
StrideLSE stride_LSE;
uint64_t seed = 0;

-DeviceAllocation<Element> block_Q;
-DeviceAllocation<Element> block_K;
-DeviceAllocation<Element> block_V;
-DeviceAllocation<ElementOut> block_O;
-DeviceAllocation<ElementAccumulatorPV> block_LSE;
-DeviceAllocation<ElementOut> block_ref_O;
-DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
+struct DeviceBuffer {
+  DeviceAllocation<Element> block_Q;
+  DeviceAllocation<Element> block_K;
+  DeviceAllocation<Element> block_V;
+  DeviceAllocation<ElementOut> block_O;
+  DeviceAllocation<ElementAccumulatorPV> block_LSE;
+  DeviceAllocation<ElementOut> block_ref_O;
+  DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
+  DeviceAllocation<int> device_cumulative_seqlen_q;
+  DeviceAllocation<int> device_cumulative_seqlen_kv;
+
+  DeviceBuffer() = default;
+  DeviceBuffer(const DeviceBuffer&) = delete;
+  DeviceBuffer& operator=(const DeviceBuffer&) = delete;
+
+  size_t get_storage_size() const {
+    return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
+      + block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
+      + block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
+      + device_cumulative_seqlen_kv.get_storage_size();
+  }
+};
+
+std::vector<std::unique_ptr<DeviceBuffer>> buffers;

std::vector<int> cumulative_seqlen_q;
std::vector<int> cumulative_seqlen_kv;
-DeviceAllocation<int> device_cumulative_seqlen_q;
-DeviceAllocation<int> device_cumulative_seqlen_kv;

//
// Methods
//
-bool verify(const ProblemShapeType& problem_shape) {
-Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
+bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
+Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
  select<0,2,3>(problem_shape),
  stride_Q);

-Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
+Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
  select<1,2,3>(problem_shape),
  stride_K);

-Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
+Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
  select<1,2,3>(problem_shape),
  stride_V);

-Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
+Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
  select<0,2,3>(problem_shape),
  stride_O);

-Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
+Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
  select<0,3>(problem_shape),
  stride_LSE);

@@ -431,7 +455,7 @@ struct FwdRunner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
-reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
+reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);

bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {

@@ -439,14 +463,13 @@ struct FwdRunner {
<< " mean " << mean_diff << std::endl;
}

-// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
+reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);

-bool passed_LSE = true; // future work
-// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
-// if ( ! passed_LSE) {
-//   std::cerr << "failed LSE: max diff " << max_diff
-//             << " mean " << mean_diff << std::endl;
-// }
+bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
+if ( ! passed_LSE) {
+  std::cerr << "failed LSE: max diff " << max_diff
+    << " mean " << mean_diff << std::endl;
+}

return passed_O && passed_LSE;
}

@@ -559,50 +582,70 @@ struct FwdRunner {
get<1,1>(stride_LSE) = 0;
}

-block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
-block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
-block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
-block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
-block_LSE.reset(size(shape_LSE));
-block_ref_O.reset(size(shape_QO));
-block_ref_LSE.reset(size(shape_LSE));
+auto buffer_init_fn = [&](auto& buffer) {
+  buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
+  buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
+  buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
+  buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
+  buffer.block_LSE.reset(size(shape_LSE));

-initialize_block(block_Q, seed + 2023, options.init_style_q);
-initialize_block(block_K, seed + 2022, options.init_style_k);
-initialize_block(block_V, seed + 2021, options.init_style_v);
+  initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
+  initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
+  initialize_block(buffer.block_V, seed + 2021, options.init_style_v);

-if ( ! cumulative_seqlen_q.empty()) {
-device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
-device_cumulative_seqlen_q.copy_from_host(
-  cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
-}
-if ( ! cumulative_seqlen_kv.empty()) {
-device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
-device_cumulative_seqlen_kv.copy_from_host(
-  cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
+  if ( ! cumulative_seqlen_q.empty()) {
+    buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
+    buffer.device_cumulative_seqlen_q.copy_from_host(
+      cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
+  }
+  if ( ! cumulative_seqlen_kv.empty()) {
+    buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
+    buffer.device_cumulative_seqlen_kv.copy_from_host(
+      cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
+  }
+};
+
+buffers.push_back(std::make_unique<DeviceBuffer>());
+buffer_init_fn(*buffers.back());
+
+int tensor_ring_buffers = options.tensor_ring_buffers;
+for (int i = 1; i < tensor_ring_buffers; i++) {
+  buffers.push_back(std::make_unique<DeviceBuffer>());
+  buffer_init_fn(*buffers.back());
}

if constexpr (kIsVarlen) {
-get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
-get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
+get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
+get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
}

return problem_shape;
}

+auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
+  auto problem_shape_ = problem_shape;
+  if constexpr (kIsVarlen) {
+    get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
+    get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
+  }
+  typename Operation::Arguments arguments{
+    problem_shape_,
+    { buffers[buffer_index]->block_Q.get(), stride_Q,
+      buffers[buffer_index]->block_K.get(), stride_K,
+      buffers[buffer_index]->block_V.get(), stride_V },
+    { buffers[buffer_index]->block_O.get(), stride_O,
+      buffers[buffer_index]->block_LSE.get(), stride_LSE },
+    hw_info
+  };
+  return arguments;
+}

ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {

ProblemShapeType problem_shape = initialize(options);

-typename Operation::Arguments arguments{
-  problem_shape,
-  { block_Q.get(), stride_Q,
-    block_K.get(), stride_K,
-    block_V.get(), stride_V },
-  { block_O.get(), stride_O,
-    block_LSE.get(), stride_LSE },
-  hw_info
-};
+int buffer_index = 0;
+typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);

Operation op;

@@ -630,11 +673,21 @@ struct FwdRunner {
}

// Run
-status = op.run();
-if (status != cutlass::Status::kSuccess) {
-  std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
-    << cudaGetErrorString(cudaGetLastError()) << std::endl;
-  return example_result;
+for (int i = 0; i < options.warmup_iterations; i++) {
+  status = op.run();
+  if (status != cutlass::Status::kSuccess) {
+    std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
+      << cudaGetErrorString(cudaGetLastError()) << std::endl;
+    return example_result;
+  }
+  buffer_index = (buffer_index + 1) % buffers.size();
+  arguments = get_arguments(problem_shape, hw_info, buffer_index);
+  status = op.update(arguments, workspace.get());
+  if (status != cutlass::Status::kSuccess) {
+    std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
+      << std::endl;
+    return example_result;
+  }
}

cudaError_t result = cudaDeviceSynchronize();

@@ -672,6 +725,14 @@ struct FwdRunner {
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
+buffer_index = (buffer_index + 1) % buffers.size();
+arguments = get_arguments(problem_shape, hw_info, buffer_index);
+status = op.update(arguments, workspace.get());
+if (status != cutlass::Status::kSuccess) {
+  std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
+    << std::endl;
+  return example_result;
+}
}

//
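The intent of the ring-buffer changes above: both the warmup and the timed loops rotate through options.tensor_ring_buffers independent Q/K/V/O allocations, calling op.update() with re-pointed arguments between launches so back-to-back iterations do not keep hitting the same hot L2 lines. A minimal sketch of the rotation (buffers, get_arguments, op, and workspace are the runner members introduced in the diff; error handling elided):

// Sketch of the per-iteration rotation in the benchmark loop.
int buffer_index = 0;
for (int i = 0; i < options.warmup_iterations; i++) {
  if (op.run() != cutlass::Status::kSuccess) break;
  // Point the kernel at the next buffer in the ring for the following launch.
  buffer_index = (buffer_index + 1) % buffers.size();
  auto arguments = get_arguments(problem_shape, hw_info, buffer_index);
  if (op.update(arguments, workspace.get()) != cutlass::Status::kSuccess) break;
}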
@@ -734,10 +795,10 @@ struct FwdRunner {
// Verify that the result is correct
bool passed = true;
if (options.verify) {
-passed = verify(problem_shape);
+passed = verify(problem_shape, *buffers[0]);
if (passed) example_result.verified = true;
}

if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;

@@ -789,10 +850,14 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {

using HeadDim = _128;

-// Persistent Tile Scheduler
-run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
-// Individual Tile Scheduler
-run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
+if (options.persistent) {
+  // Persistent Tile Scheduler
+  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
+}
+else {
+  // Individual Tile Scheduler
+  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
+}
}

///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -818,10 +883,14 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {

using HeadDim = _64;

-// Persistent Tile Scheduler
-run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
-// Individual Tile Scheduler
-run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
+if (options.persistent) {
+  // Persistent Tile Scheduler
+  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
+}
+else {
+  // Individual Tile Scheduler
+  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
+}
}

@@ -845,10 +914,14 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
using HeadDim = _32;

#ifdef FP8
-// Persistent Tile Scheduler
-run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
-// Individual Tile Scheduler
-run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
+if (options.persistent) {
+  // Persistent Tile Scheduler
+  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
+}
+else {
+  // Individual Tile Scheduler
+  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
+}
#endif
}

@@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
///////////////////////////////////////////////////////////////////////////////////////////////////

enum class InitStyle {
-  kOne, kLinearStride128, kLinearStride1, kRandom, kNone
+  kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
};

///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -98,6 +98,9 @@ struct Options {
if (s == "r") {
  dst = InitStyle::kRandom;
}
+else if (s == "l") {
+  dst = InitStyle::kRandomLarge;
+}
else if (s == "1") {
  dst = InitStyle::kOne;
}

@@ -203,6 +206,11 @@ void initialize_block(
  block.get(), block.size(), seed, (Element) -1, (Element) 1);
  break;
}
+case InitStyle::kRandomLarge: {
+  cutlass::reference::device::BlockFillRandomGaussian(
+    block.get(), block.size(), seed, (Element) -1, (Element) 100);
+  break;
+}
case InitStyle::kLinearStride1: {
  std::vector<Element> data(block.size());
  for (size_t i = 0; i < block.size() / 128; i ++) {
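InitStyle::kRandomLarge (selected with --init-style=l) reuses the existing Gaussian filler but with far more dispersed values; assuming BlockFillRandomGaussian's trailing parameters are (mean, stddev) as in the CUTLASS reference utilities, this draws from roughly N(-1, 100^2), which stresses the softmax/LSE path with large-magnitude logits:

// Sketch of the new case; only the distribution parameters differ from kRandom.
cutlass::reference::device::BlockFillRandomGaussian(
    block.get(), block.size(), seed, (Element)(-1), (Element)(100));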
@@ -144,4 +144,23 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
endforeach()

+# Add a target that builds all examples
+add_custom_target(77_blackwell_fmha_all
+  DEPENDS
+  77_blackwell_fmha_fp8
+  77_blackwell_fmha_fp16
+  77_blackwell_fmha_gen_fp8
+  77_blackwell_fmha_gen_fp16
+  77_blackwell_mla_2sm_fp8
+  77_blackwell_mla_2sm_fp16
+  77_blackwell_mla_2sm_cpasync_fp8
+  77_blackwell_mla_2sm_cpasync_fp16
+  77_blackwell_mla_b2b_2sm_fp8
+  77_blackwell_mla_b2b_2sm_fp16
+  77_blackwell_fmha_bwd_fp8
+  77_blackwell_fmha_bwd_fp16
+  77_blackwell_fmha_bwd_sat_fp8
+  77_blackwell_fmha_bwd_sat_fp16
+)
endif()
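With the aggregate target in place, all FMHA, MLA, and backward variants can be built in one step, e.g. cmake --build . --target 77_blackwell_fmha_all (invocation shown for illustration; the target names come from the DEPENDS list above).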
@@ -157,8 +157,8 @@ struct CausalMask : NoMask {
  TileShape const& tile_shape,
  ProblemSize const& problem_size) {

  int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
  return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
}

template<class BlkCoord, class TileShape, class ProblemSize>

@@ -42,7 +42,7 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
-class StrideLSE // Q, B
+class StrideLSE_ // Q, B
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

@@ -54,6 +54,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
+using StrideLSE = StrideLSE_;

struct TensorStorage {

@@ -79,6 +80,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

struct Params {
  TMA_O tma_store_o;
+
+  ElementAcc* ptr_LSE;
+  StrideLSE dLSE;
};

template<class ProblemShape>

@@ -110,7 +114,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
);

return {
-  tma_store_o
+  tma_store_o,
+  args.ptr_LSE,
+  args.dLSE
};
}

@@ -119,6 +125,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
  cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}

+const Params& params;
+
+CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}

template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(
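These epilogue hunks thread the LSE output through Params: the stride type becomes a template-parameter alias (StrideLSE_ re-exported as StrideLSE), Params gains ptr_LSE/dLSE, to_underlying_arguments forwards them, and the epilogue now stores a reference to its Params so the mainloop's correction step can reach them. A trimmed sketch of the plumbing (members reduced to the LSE-relevant ones; not the full class):

// Trimmed sketch; names follow the diff, everything else omitted.
struct ParamsSketch {
  ElementAcc* ptr_LSE;  // global LSE output; nullptr skips the store
  StrideLSE dLSE;       // stride of the (Q, B) LSE tensor
};

template <class Args>
ParamsSketch to_params_sketch(Args const& args) {
  return { args.ptr_LSE, args.dLSE };
}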
@@ -531,7 +531,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));

// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem

@@ -613,7 +613,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;

const int kReleasePipeCount = 10; // must be multiple of 2

order_s.wait();

CUTLASS_PRAGMA_UNROLL

@@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);

if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
  order_s.arrive();
}

@@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;

row_sum = local_row_sum;

if (final_call) {

@@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);

using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem

typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);

Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));

@@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);

Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));

@@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);

Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));

copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);

#ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {

@@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 16;

using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem

typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);

Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));

tOtO_i.data() = tOtO_i.data().get() + tmem_O;

auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);

Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);

@@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);

Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));

auto copy_in = [&](int i) {
  Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
  tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);

@@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}

-template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
+template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
CUTLASS_DEVICE auto
correction(
  BlkCoord const& blk_coord,

@@ -951,7 +951,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
-PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
+PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
+CollectiveEpilogue& epilogue) {

int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);

@@ -961,7 +962,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);

Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));

@@ -1060,13 +1061,25 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
+Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);

correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);

+if (epilogue.params.ptr_LSE != nullptr) {
+  int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
+
+  ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
+
+  if (row_idx < get<0>(problem_shape)) {
+    gLSE(row_idx, get<2>(blk_coord)) = lse;
+  }
+}

cutlass::arch::fence_view_async_tmem_load();

pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;

pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;

@@ -1083,6 +1096,16 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);

+if (epilogue.params.ptr_LSE != nullptr) {
+  int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
+
+  ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
+
+  if (row_idx < get<0>(problem_shape)) {
+    gLSE(row_idx, get<2>(blk_coord)) = lse;
+  }
+}

cutlass::arch::fence_view_async_tmem_load();

pipeline_o.consumer_release(pipeline_o_consumer_state);
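The stored quantity in both branches above is the standard log-sum-exp recovered from the running softmax statistics: row_sum was accumulated as sum_i exp(scale_softmax * s_i - scale_softmax * row_max), so

lse = log(row_sum) + scale_softmax * row_max = log(sum_i exp(scale_softmax * s_i)),

which is exactly cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax) in the code. The row_idx bound against get<0>(problem_shape) keeps partial tiles from writing past the sequence end, and the second branch offsets by get<0>(TileShapeQK{}) because it handles the second half of the tile.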
@@ -118,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;

// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
  cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,

@@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);

return Params{
  args.problem_shape,
  args.mainloop,

@@ -452,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});

auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV);

@@ -477,7 +477,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));

// set up lse and sum_odo

auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;

pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);

@@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}

// load Q
if (cute::elect_one_sync()) {
  cute::copy(
    mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
    tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),

@@ -520,7 +520,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;

@@ -529,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);

pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);

// load V
if (cute::elect_one_sync()) {
  cute::copy(

@@ -540,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}

// load dO
if (cute::elect_one_sync()) {
  cute::copy(
    mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
    tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),

@@ -573,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);

// load Q
if (cute::elect_one_sync()) {
  cute::copy(
    mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
    tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),

@@ -584,7 +584,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_q_producer_state;

pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);

// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;

@@ -593,15 +593,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;

pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);

// load dO
if (cute::elect_one_sync()) {
  cute::copy(
    mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
    tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),

@@ -612,7 +612,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_do_producer_state;

pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);

// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;

@@ -621,7 +621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;

@@ -639,23 +639,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {

auto [Q, K, D, HB] = problem_shape;

auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});

@@ -685,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);

TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk;

@@ -923,6 +923,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord,
TensorShape const& tensor_shape) {

+Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
+
auto copy_op = make_cotiled_copy(
  Copy_Atom<UniversalCopy<uint128_t>, Element>{},
  make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),

@@ -930,21 +932,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
auto thr_copy = copy_op.get_slice(_0{});

-auto tCg = thr_copy.partition_D(gmem);
-auto tCr = thr_copy.partition_S(quantize(regs));
-auto tCc = thr_copy.partition_D(coord);
+Tensor tCg = thr_copy.partition_D(gmem);
+Tensor tCr = thr_copy.partition_S(quantize(regs));
+Tensor tPc = thr_copy.partition_D(preds);

-constexpr int R = decltype(tCr.layout())::rank;
-auto tCg_v = group_modes<1, R>(tCg);
-auto tCr_v = group_modes<1, R>(tCr);
-auto tCc_v = group_modes<1, R>(tCc);
-auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
-
-for (int i = 0; i < size(tCp_v); ++i) {
-  tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
-}
-
-copy_if(copy_op, tCp_v, tCr_v, tCg_v);
+copy_if(copy_op, tPc, tCr, tCg);
}
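The refactor above replaces the hand-rolled predicate loop with a predicate tensor built once via cute::lazy::transform and consumed directly by copy_if. A hedged sketch of the pattern (tensor names mirror the diff; the surrounding copy-atom construction is elided):

// Build a boolean view over the coordinate tensor: true iff in bounds.
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) {
  return elem_less(c, tensor_shape);
});
// Partition the predicates exactly like the destination so indices line up,
// then let copy_if mask the out-of-bounds elements of the vectorized store.
Tensor tCg = thr_copy.partition_D(gmem);
Tensor tCr = thr_copy.partition_S(quantize(regs));
Tensor tPc = thr_copy.partition_D(preds);
copy_if(copy_op, tPc, tCr, tCg);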
@ -1073,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
||||
|
||||
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
|
||||
// in tmem, S & P overlap
|
||||
@ -1114,7 +1106,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
|
||||
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
|
||||
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
|
||||
|
||||
|
||||
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
|
||||
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
|
||||
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
|
||||
@ -1152,20 +1144,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
fn(cute::false_type{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
|
||||
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {
|
||||
|
||||
// compute P = softmax(S, LSE)
|
||||
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
|
||||
|
||||
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
|
||||
Mask{}.apply_mask(tTR_rST, [&](int i) {
|
||||
auto c_transpose = tTR_cST(i);
|
||||
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
|
||||
}, problem_shape);
|
||||
}
|
||||
|
||||
|
||||
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
|
||||
float2 softmax_scale_log2_e;
|
||||
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
|
||||
@ -1184,16 +1176,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
tTR_rST(i) = ::exp2f(out.x);
|
||||
tTR_rST(i+1) = ::exp2f(out.y);
|
||||
}
|
||||
|
||||
|
||||
auto tRT_rST = quantize(tTR_rST);
|
||||
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
|
||||
|
||||
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
cutlass::arch::NamedBarrier(
|
||||
kNumComputeWarps * NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::TransformBarrier
|
||||
).arrive_and_wait();
|
||||
|
||||
|
||||
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
|
||||
});
|
||||
|
||||
@ -1293,9 +1285,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {

using X = Underscore;

auto [Q, K, D, HB] = problem_shape;

auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
@ -1307,7 +1299,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDQtDQ.data() = TmemAllocation::kDQ;

Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);

Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1376,7 +1368,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;
}
}

CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync();
@ -1561,7 +1553,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;

auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
@ -1587,7 +1579,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();

load(
blk_coord,
problem_shape,
@ -1596,7 +1588,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
@ -1608,7 +1600,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();

mma(
blk_coord,
problem_shape,
@ -1616,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
@ -1629,7 +1621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();

compute(
blk_coord,
problem_shape,
@ -1660,7 +1652,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();

reduce(
blk_coord,
problem_shape,
@ -1677,9 +1669,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();

/* no-op */

}
}

@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();

CollectiveMainloop mainloop;
CollectiveEpilogue epilogue;
CollectiveEpilogue epilogue{params.epilogue};

if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
@ -407,7 +407,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);

@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}

CUTLASS_PRAGMA_UNROLL
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {

sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);

ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}

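(The hunk above restores the plain streaming log-sum-exp merge,

$$\mathrm{LSE} = \log\Big(\sum_i e^{\ell_i - \ell_{\max}}\Big) + \ell_{\max},$$

which is exact when every partial $\ell_i$ and $\ell_{\max}$ are expressed in the same units; multiplying $\ell_{\max}$ by `params.scale` again, as the removed lines did, would double-count a softmax scale that the partial LSE values evidently already carry.)
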
@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}

if (threadIdx.x == 0) {
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}

@ -75,6 +75,8 @@ struct DeviceAllocation {

size_t size() const { return size_; }

size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }

void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess);

@ -280,7 +280,7 @@ auto make_iterator(T* ptr) {
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {

@ -133,7 +133,7 @@ using TP = _8;
static constexpr int TP_ = TP{};

#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

// Distributed GEMM tiling/sharding schedule
// Choices:
@ -254,7 +254,8 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];

#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@ -347,7 +348,7 @@ struct Result {
};

#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -805,17 +806,16 @@ int run(Options &options) {
return 0;
}

#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
// and must have compute capability at least 90.
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
// CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run Blackwell kernels.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
@ -861,11 +861,11 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//

#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
run(options);
#else
std::cerr
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.4 or later." << std::endl;
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.8 or later." << std::endl;
return 0;
#endif

@ -14,8 +14,8 @@ cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
### Minimum software

Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
CUDA graph APIs.
This example specifically requires CUDA Toolkit 12.8 or newer, since that is the first version
supporting the Blackwell architecture.

### Hardware / driver settings

1192
examples/88_hopper_fmha/88_hopper_fmha.cu
Normal file
File diff suppressed because it is too large
50
examples/88_hopper_fmha/CMakeLists.txt
Normal file
@ -0,0 +1,50 @@
# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
88_hopper_fmha
88_hopper_fmha.cu
)

if(NOT WIN32 AND NOT CUTLASS_CLANG_HOST_COMPILE)

set_property(
SOURCE 88_hopper_fmha.cu
PROPERTY COMPILE_FLAGS "--use_fast_math"
)

cutlass_example_add_executable(
88_hopper_fmha_fp8
88_hopper_fmha.cu
)

target_compile_definitions(
88_hopper_fmha_fp8
PRIVATE FP8)

endif()

77
examples/88_hopper_fmha/README.md
Normal file
@ -0,0 +1,77 @@
# CUTLASS Hopper FMHA Example

This sample showcases how to implement fused multi-head attention (FMHA) using
CUTLASS for the NVIDIA Hopper architecture. At its heart, the forward pass of
FMHA is a GEMM-online-softmax-GEMM fusion, whereas the backward pass has a slightly
more complex structure (essentially a GEMM-softmax-2xGEMM-2xGEMM fusion).
For more information, please refer to the [Flash Attention 3 paper](https://arxiv.org/abs/2407.08608).

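In our notation (a sketch of the standard formulation, not code taken from this example), with softmax scale $s = 1/\sqrt{d}$ the forward pass computes

$$S = s\,QK^\top, \qquad P = \mathrm{softmax}(S), \qquad O = P\,V,$$

and the backward pass recomputes $P$ from $S$ and the stored log-sum-exp, then forms

$$dP = dO\,V^\top, \quad dS = P \circ \big(dP - \mathrm{rowsum}(dO \circ O)\big), \quad dV = P^\top dO, \quad dQ = s\,dS\,K, \quad dK = s\,dS^\top Q,$$

which is where the GEMM-softmax-2xGEMM-2xGEMM shape of the backward pass comes from.
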
The forward pass kernel supports head dims 32, 64, 128, and 256 for fp16 and bf16 input data types,
and head dims 128 and 256 for fp8.
All kernels use the Tensor Memory Accelerator for loads.
Kernels with head dims 128 and 256 have warp-specialized cooperative schedules.

Backward pass kernels (fp16 only) support head dims 32, 64, and 128; all of them use
warp-specialized cooperative schedules.

## Customization

### Mask Fusion

Similar to the [Blackwell FMHA example](../77_blackwell_fmha/README.md), attention masks such as
causal masking can be fused into the kernel. To modify the code for such fusions,
`collective/fmha_fusion.hpp` provides the easiest customization point.
The `before_softmax` function is called with the accumulator of the first GEMM and the logical
positions of those elements. It is well-suited for applying masks or activations.

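As a concrete illustration, a causal mask in this style might look like the following minimal sketch (the struct name and the coordinate handling are ours, and the usual CUTLASS/CuTe headers are assumed; only the `before_softmax(acc, index, problem_size)` entry point mirrors the customization point described above):

```cpp
// Hypothetical fusion sketch; not part of the example source.
struct CausalFusionSketch {
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE void
  before_softmax(AccQK& acc_qk, IndexQK const& index_qk, ProblemSize const& problem_size) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      // index_qk(i) yields the logical (q, k) position of accumulator element i;
      // mask out keys that lie in the future of the current query.
      if (get<1>(index_qk(i)) > get<0>(index_qk(i))) {
        acc_qk(i) = -std::numeric_limits<float>::infinity();
      }
    }
  }
};
```

Because a fusion of this shape runs on the first GEMM's accumulator while it is still in registers, it adds no extra global memory traffic.
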
### MHA Variants

Using CuTe, it is easy to represent the various attention variants.
Where regular multi-head attention's layout for the head dimension is (numHeads:headStride),
for single-head attention it is simply (1:0) everywhere,
for GQA it is unchanged in Q and (numHeads/numGroups,numGroups:headStride,0) in KV,
and for MQA it is unchanged in Q and (numHeads:0) in KV.
As such, beyond general stride handling, no additional work is needed to support these,
and the example simply demonstrates regular multi-head attention.

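The shape:stride notation above maps directly onto `cute::make_layout`; the following small sketch (with illustrative values, not taken from the example) spells out the four head-mode layouts:

```cpp
#include <cute/layout.hpp>
using namespace cute;

void head_mode_layouts() {
  int num_heads = 8, num_groups = 2, head_stride = 128;

  // MHA: (numHeads:headStride) -- every Q head has its own KV head
  auto h_mha = make_layout(make_shape(num_heads), make_stride(head_stride));
  // Single-head: (1:0)
  auto h_sha = make_layout(make_shape(1), make_stride(0));
  // GQA in KV: (numHeads/numGroups,numGroups : headStride,0)
  auto h_gqa = make_layout(make_shape(num_heads / num_groups, num_groups),
                           make_stride(head_stride, 0));
  // MQA in KV: (numHeads:0) -- all Q heads read one shared KV head
  auto h_mqa = make_layout(make_shape(num_heads), make_stride(0));
}
```

The zero strides collapse the shared modes, so the kernel's ordinary stride handling addresses all four variants without any code changes.
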
### FP8

The warp-specialized forward kernel supports FP8 computation, with both FP32 and FP16
accumulation for the Q*K product. Both variants can be enabled in the runner by defining FP8.

## Performance

Forward pass kernels generally come close to the performance of FA3, whereas backward pass
kernels are more limited and are not expected to reach the same level of performance as FA3.

# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
@ -0,0 +1,863 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::collective {

template<
typename Element_,
typename ElementAccumulator_,
typename TileShape_, // BlockQO, BlockKV, BlockHead
class Fusion,
class... Options
>
struct FmhaBwdMainloopTmaWarpSpecialized {

using Element = Element_;
using ElementAccumulator = ElementAccumulator_;
using TileShape = TileShape_;

static constexpr bool kIsPersistent = false;

static const int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = 2;
static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups;
static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */;

static const int kOuterLoads = 2;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<_1, _1, _1>;
static_assert(StagesQ::value >= 2);
static_assert(Stages::value >= 2 * NumMmaWarpGroups);

// 16B alignment lets us use TMA
static constexpr int Alignment = 16 / sizeof(Element);

using TileShapeNM = Shape< // (N,M,D)
decltype(tuple_element_t<1, TileShape>{} / Int<NumMmaWarpGroups>{}),
tuple_element_t<0, TileShape>,
tuple_element_t<2, TileShape>>;

using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M)

using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N)

using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeNM, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeND, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeND, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment, // from smem, might matter (?)
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeMD, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

using TiledMmaNM = typename CollectiveMmaNM::TiledMma;
using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma;
using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{}));
using TiledMmaND = TiledMmaND_RS;
using TiledMmaMD = typename CollectiveMmaMD::TiledMma;

using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB;
using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA;
using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA;
using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB;

//using SmemLayoutDQ = Layout<
// Shape<
// tuple_element_t<0, TileShapeMD>,
// Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>,
// _2
// >,
// Stride<
// _4,
// Stride<decltype(tuple_element_t<0, TileShapeMD>{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>,
// decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
// >>;

using SmemLayoutDQ_0 = Layout<
Shape<
tuple_element_t<0, TileShapeMD>,
tuple_element_t<1, TileShapeMD>,
_2
>,
Stride<
tuple_element_t<1, TileShapeMD>,
_1,
decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
>>;

using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>());
using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{}));
using SmemLayoutDQ = SmemLayoutDQ_1;

using PipelineDQ = cutlass::PipelineAsync<2>;

using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int<NumMmaWarpGroups>{}));

using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom<Element>{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{}));
using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB;

using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB;
using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB;

using SmemLayoutLSE = Layout<Shape<tuple_element_t<1, TileShapeNM>, Int<StageCount>>>;

using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;

using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;

using TileShapePV = TileShapeND; // To work with the kernel level
using TiledMmaPV = TiledMmaND;

static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator);
static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);

struct SharedStorage {
// One for each consumer WG
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutKp>> smem_kp;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};

cute::array_aligned<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;

// Loaded by producer, consumed by both WGs
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQp>> smem_qp;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDOp>> smem_dop;
};

// Accumulated into by both consumers, potentially loaded, potentially written
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutDQ>> smem_dq;

union {
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_lse;
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_sumOdO;
};
};

struct Arguments {
const Element* ptr_Q;
cute::tuple<int, int, int, _1> dQ;
const Element* ptr_K;
cute::tuple<int, int, int, _1> dK;
const Element* ptr_V;
cute::tuple<int, int, int, _1> dV;

const Element* ptr_dO;
cute::tuple<int, int, int, _1> dDO;

const ElementAccumulator* ptr_LSE;
cute::tuple<int, int, _1> dLSE;
const ElementAccumulator* ptr_sum_OdO;
cute::tuple<int, int, _1> dSumOdO;

ElementAccumulator* ptr_dQ;
cute::tuple<int, int, int, _1> dDQ;
};

using TMA_Q = typename CollectiveMmaNM::Params::TMA_B;
using TMA_K = typename CollectiveMmaNM::Params::TMA_A;
using TMA_V = typename CollectiveMmaNM::Params::TMA_A;
using TMA_DO = typename CollectiveMmaNM::Params::TMA_B;

using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{})));
using TMA_ODO = TMA_LSE;

using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{})));

using LoadQ = CollectiveLoadTma<
LoadKind::kBwdM,
MainloopPipeline,
Element,
SmemLayoutQ,
TMA_Q
>;

using LoadK = CollectiveLoadTma<
LoadKind::kBwdN,
MainloopPipelineQ,
Element,
SmemLayoutK,
TMA_K
>;

using LoadV = CollectiveLoadTma<
LoadKind::kBwdN,
MainloopPipelineQ,
Element,
SmemLayoutV,
TMA_V
>;

using LoadDO = CollectiveLoadTma<
LoadKind::kBwdM,
MainloopPipeline,
Element,
SmemLayoutDO,
TMA_DO
>;

using LoadLSE = CollectiveLoadTma<
LoadKind::kBwdScalar,
MainloopPipeline,
ElementAccumulator,
SmemLayoutLSE,
TMA_LSE
>;

using LoadODO = CollectiveLoadTma<
LoadKind::kBwdScalar,
MainloopPipeline,
ElementAccumulator,
SmemLayoutLSE,
TMA_ODO
>;

struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
TMA_DO tma_load_do;

TMA_LSE tma_load_lse;
TMA_ODO tma_load_odo;

TMA_DQ tma_red_dq;

float scale_softmax;
float scale_softmax_log2;
};

static_assert(size(TiledMmaNM{}) == size(TiledMmaND{}));
static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{}));

template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
return true
&& (get<4>(problem_size) <= get<2>(TileShape{}))
&& ((get<4>(problem_size) % Alignment) == 0)
&& ((get<2>(problem_size) % Alignment) == 0)
;
}

template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));

auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK)));
auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ)));
auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
typename CollectiveMmaNM::Arguments {
args.ptr_K, dK,
args.ptr_Q, dQ,
}, /*workspace=*/ nullptr);

auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV)));
auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO)));
auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
typename CollectiveMmaNM::Arguments {
args.ptr_V, dV,
args.ptr_dO, dDO,
}, /*workspace=*/ nullptr);

TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{}));
TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{}));

TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{}));

return Params{
params_nm_kq.tma_load_b,
params_nm_kq.tma_load_a,
params_nm_vdo.tma_load_a,
params_nm_vdo.tma_load_b,
tma_load_lse, tma_load_odo,
tma_red_dq,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size)))
};
}

template<class BlkCoord, class ProblemSize>
CUTLASS_DEVICE
auto
get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) {
return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
}

CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor());
}

template<bool kLoadOuter, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_kv_maybe_q(
int block_rank_in_cluster,
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
// Load pattern:
// K0 V0 K1 V1
// Q0 DO0 Q1 DO1 Q2 DO2 ...
// K0 Q0 V0 K1 DO0 V1 ...
int lane_predicate = cute::elect_one_sync();

int outer_tile_count = NumMmaWarpGroups;
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);

auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count);

uint16_t mcast_mask_b = 0;

LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q};
auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);

LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do};
auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);

LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse};
auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO};
auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

outer_tile_count *= 2; // K & V
inner_tile_count *= 4; // Q & dO & LSE & sumOdO

while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 4;
++inner_tile_iter;
}

if constexpr (kLoadOuter) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}

load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);

if constexpr (! kLoadOuter) {
if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
}

if constexpr (kLoadOuter) {
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}

load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);

if constexpr (kLoadOuter) {
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}

if constexpr (kLoadOuter) {
while (outer_tile_count > 0) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
}

CUTLASS_PRAGMA_NO_UNROLL
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 4;
++inner_tile_iter;
}
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);

load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
}
}

template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_maybe_q(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
// Load pattern:
// K0 V0 K1 V1
// Q0 DO0 Q1 DO1 Q2 DO2 ...
// K0 Q0 V0 K1 DO0 V1 ...
int lane_predicate = cute::elect_one_sync();

int outer_tile_count = NumMmaWarpGroups;

auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);

LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

outer_tile_count *= 2; // K & V

load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);

if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}

load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);

while (outer_tile_count > 0) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
}

template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
CUTLASS_DEVICE void
reduce(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer,
SharedStorage& storage)
{
int lane_predicate = cute::elect_one_sync();

Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size));
Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{});
Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord));
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});

auto block_tma = params.tma_red_dq.get_slice(_0{});

Tensor tDQsDQ = block_tma.partition_S(sDQ);
Tensor tDQgDQ = block_tma.partition_D(gDQ);

int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
int g_index = 0;

auto smem_pipe_release_reducer = smem_pipe_read_reducer;
bool first = true;
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 1;
++g_index;
}
if (inner_tile_count == 0) break;

pipeline_reducer.consumer_wait(smem_pipe_read_reducer);
if (lane_predicate == 1) {
tma_store_wait<1>();
}
if (! first) {
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
++smem_pipe_release_reducer;
} else {
first = false;
}
if (lane_predicate == 1) {
copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index));
tma_store_arrive();
}
++smem_pipe_read_reducer;
--inner_tile_count;
++g_index;
}
if (lane_predicate) {
tma_store_wait<0>();
}
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
++smem_pipe_release_reducer;
}

template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
CUTLASS_DEVICE auto
compute(
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
SharedStorage& storage,
MathWgOrderBarrier& math_wg_order_barrier)
{
TiledMmaND tiled_mma_nd;

Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
clear(acc_DV);

Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
clear(acc_DK);

int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup;

PipelineState smem_pipe_release_inner = smem_pipe_read_inner;

pipeline_outer.consumer_wait(smem_pipe_read_outer);
PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer;
++smem_pipe_read_outer;
pipeline_outer.consumer_wait(smem_pipe_read_outer);
PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer;

int inner_tile_count = get_inner_tile_count(wg_coord, problem_size);

TiledMmaNM tiled_mma_nm;
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx);
Tensor tSsK = thr_mma_nm.partition_A(sK);
Tensor tSsQ = thr_mma_nm.partition_B(sQ);
Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK);
Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ);

Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{});

Tensor tDPsV = thr_mma_nm.partition_A(sV);
Tensor tDPsDO = thr_mma_nm.partition_B(sDO);
Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV);
Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO);

auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx);

Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{});
Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp);
Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO);

Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{});
Tensor tDK_sQ = thr_mma_nd.partition_B(sQp);
Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ);

int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0);

TiledMmaMD tiled_mma_md;
auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx);
Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{});
Tensor tDQsDS = thr_mma_md.partition_A(sDS);
Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS);
Tensor tDQrDS = tDQrDS_full(_,_,_,_);
Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{});
Tensor tDQsK = thr_mma_md.partition_B(sKp);
Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK);

Tensor sLSE = make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
Tensor tSsLSE = thr_mma_nm.partition_C(sLSE);

Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
Tensor tDPsODO = thr_mma_nm.partition_C(sODO);

Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{}));
Tensor tScS = thr_mma_nm.partition_C(cS);
int n_block = get<1>(wg_coord);
tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{});

// Transpose
Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS))));
Tensor sDSp = sDSp_full(_,_,_);
Tensor tDPsDS = thr_mma_nm.partition_C(sDSp);

auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx);
Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp);

Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp);

Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
auto tDQsDQ_full = thr_mma_md.partition_C(sDQ);

auto smem_pipe_read_k_other = smem_pipe_read_k;
smem_pipe_read_k_other.advance(2);

int k_index = 0;

while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 1;
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
k_index += 1;
}
if (inner_tile_count == 0) break;

pipeline_inner.consumer_wait(smem_pipe_read_inner);
PipelineState smem_pipe_read_q = smem_pipe_read_inner;
++smem_pipe_read_inner;
PipelineState smem_pipe_read_do = smem_pipe_read_inner;
++smem_pipe_read_inner;

// GEMM KQ -> S
Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));

warpgroup_fence_operand(acc_S);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S);
warpgroup_commit_batch();

pipeline_inner.consumer_wait(smem_pipe_read_do);

// GEMM VdO -> dP
Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));

warpgroup_fence_operand(acc_DP);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP);
warpgroup_commit_batch();

Tensor reg_LSE = make_fragment_like<ElementAccumulator>(acc_S);
for (int i = 0; i < size(reg_LSE); i++) {
reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i);
}

Tensor reg_ODO = make_fragment_like<ElementAccumulator>(acc_S);
if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) {
for (int i = 0; i < size(reg_ODO); i++) {
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
}
}

warpgroup_wait<1>();
warpgroup_fence_operand(acc_S);

math_wg_order_barrier.wait();
// Compute S -> P
Fusion{}.before_softmax(acc_S, tScS, problem_size);
auto acc_P = make_fragment_like<ElementAccumulator>(acc_S);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_P); i++) {
acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i));
}
math_wg_order_barrier.arrive();

if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) {
for (int i = 0; i < size(reg_ODO); i++) {
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
}
}

warpgroup_wait<0>();
warpgroup_fence_operand(acc_DP);

// Compute dP P -> dS
auto acc_DS = make_fragment_like<Element>(acc_DP);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DS); i++) {
// We could move the scale out and into the respective epilogues (or a final scaling step)
acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i));
}

// GEMM PdO -> dV
auto op_P = make_acc_into_op<Element>(acc_P, typename TiledMmaND::LayoutA_TV{});
warpgroup_fence_operand(acc_DV);
warpgroup_fence_operand(op_P);
warpgroup_arrive();
cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV);
warpgroup_commit_batch();

// Store dS to smem dS'
if (wg_idx == 0) math_wg_order_barrier.wait();

auto recast_bits = [](auto sz, auto t) {
return recast<uint_bit_t<decltype(sz)::value>>(t);
};
auto tDPsDS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, tDPsDS);
auto acc_DS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, acc_DS);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DS_v); i++) {
tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i);
}

cutlass::arch::fence_view_async_shared();
if (wg_idx == 0) math_wg_order_barrier.arrive();

// GEMM dS Q -> dK
if (wg_idx == 1) {

math_wg_order_barrier.wait();

// GEMM dS' K -> dQ
Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{}));

warpgroup_fence_operand(acc_DQ);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ);
cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ);
warpgroup_commit_batch();

warpgroup_fence_operand(acc_DK);
warpgroup_arrive();
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
warpgroup_commit_batch();

warpgroup_wait<1>();
warpgroup_fence_operand(acc_DK);

warpgroup_wait<1>();
warpgroup_fence_operand(acc_DQ);

math_wg_order_barrier.arrive();

pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index());

// Store dQ to smem dQ'
// Invoke TMA reduce on dQ'
using Vec = uint_bit_t<sizeof_bits_v<ElementAccumulator> * 2>;
auto tDQsDQ_v = recast<Vec>(tDQsDQ);
auto acc_DQ_v = recast<Vec>(acc_DQ);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DQ_v); i++) {
tDQsDQ_v(i) = acc_DQ_v(i);
}

cutlass::arch::fence_view_async_shared();

pipeline_reducer.producer_commit(smem_pipe_write_reducer);
++smem_pipe_write_reducer;
} else {

warpgroup_fence_operand(acc_DK);
warpgroup_arrive();
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
warpgroup_commit_batch();

warpgroup_wait<1>();
warpgroup_fence_operand(acc_DK);

pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
++smem_pipe_write_reducer;
}

--inner_tile_count;

pipeline_inner.consumer_release(smem_pipe_release_inner);
++smem_pipe_release_inner;
pipeline_inner.consumer_release(smem_pipe_release_inner);
++smem_pipe_release_inner;

tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
k_index += 1;
}

pipeline_outer.consumer_release(smem_pipe_read_k);
pipeline_outer.consumer_release(smem_pipe_read_outer);
pipeline_reducer.producer_tail(smem_pipe_write_reducer);
++smem_pipe_read_outer;

warpgroup_wait<0>();
warpgroup_fence_operand(acc_DK);
warpgroup_fence_operand(acc_DV);

return make_tuple(acc_DK, acc_DV);
}
};

} // namespace cutlass::fmha::collective

140
examples/88_hopper_fmha/collective/fmha_collective_load.hpp
Normal file
@ -0,0 +1,140 @@
|
||||
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::collective {

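// Forward loads (kQ, kK, kV) fetch the Q/K/V tiles; the backward kinds load
// along the N or M sequence dimension, or a tensor indexed by sequence
// position only (e.g. per-row LSE terms).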
enum class LoadKind {
  kQ, kK, kV,
  kBwdN, kBwdM, kBwdScalar
};

template<
  LoadKind kKind,
  class Pipeline,
  class Element,
  class SmemLayout,
  class TMA
>
struct CollectiveLoadTma {

  using Params = TMA;
  using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayout>>;
  using PipelineState = typename cutlass::PipelineState<Pipeline::Stages>;

  Params const& params;
  Pipeline& pipeline;
  SharedStorage& storage;

  CUTLASS_DEVICE
  CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage)
      : params(params), pipeline(pipeline), storage(storage) {}

  template<class ProblemSize, class TileShape, class BlockCoord>
  CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape,
      BlockCoord const& blk_coord, int loop_count
  ) {
    using X = Underscore;
    if constexpr (kKind == LoadKind::kK) {
      Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
      Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
      Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord));
      return gK;
    } else if constexpr (kKind == LoadKind::kQ) {
      Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
      Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
      Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord));
      return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout());
    } else if constexpr (kKind == LoadKind::kV) {
      Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size)));
      Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
      Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord));
      return gV;
    } else if constexpr (kKind == LoadKind::kBwdN) {
      Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
      Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
      Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
      return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout());
    } else if constexpr (kKind == LoadKind::kBwdM) {
      Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
      Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
      Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
      return g;
    } else if constexpr (kKind == LoadKind::kBwdScalar) {
      Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size));
      Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, X>{});
      Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord));
      return g;
    }
  }

  template<class ClusterRank, class ProblemSize, class TileShape, class BlockCoord>
  CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster,
      ProblemSize const& problem_size, TileShape const& tile_shape,
      BlockCoord const& block_coord, int loop_count
  ) {
    Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count);
    Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{});

    auto block_tma = params.get_slice(block_rank_in_cluster);
    Tensor ts = block_tma.partition_D(s);
    Tensor tg = block_tma.partition_S(g);

    return make_tuple(tg, ts);
  }

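  // Issues one TMA load for the current tile. The kAdvanceIterator and
  // kAdvancePipe flags let two loads (e.g. K and V) share a tile iterator
  // and pipeline: the first steps with kAdvanceIterator=false so the second
  // can reuse the same tile index before advancing it.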
  template<bool kAdvanceIterator=true, bool kAdvancePipe=true, bool kAcquireBarrier=true, class TileIterator, class State>
  CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state,
      PipelineState& smem_pipe_write,
      int lane_predicate, int& tile_count, uint16_t mcast_mask = 0
  ) {
    if ((lane_predicate == 1) && (tile_count > 0)) {
      if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write);
      using BarrierType = typename Pipeline::ProducerBarrierType;
      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

      if constexpr (kKind == LoadKind::kBwdScalar) {
        copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index()));
      } else {
        copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index()));
      }
      if constexpr (kAdvancePipe) ++smem_pipe_write;
      if constexpr (kAdvanceIterator) ++tile_iter;
    }
    --tile_count;
  }
};

} // namespace cutlass::fmha::collective
305  examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp  Normal file
@ -0,0 +1,305 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

#include "../collective/fmha_common.hpp"

namespace cutlass::fmha::collective {

template<
  class ElementAccumulator,
  class Fusion,
  class Params
>
struct CollectiveSoftmax {
  Params const& params;
  CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {}

  using SumType = float;
  using MaxType = ElementAccumulator;

  template<class AccPV, class TiledMmaPV>
  CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) {
    Tensor s_max = make_fragment_like<MaxType>(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout())));
    Tensor a_sum = make_fragment_like<SumType>(s_max);
    return make_tuple(s_max, a_sum);
  }

  CUTLASS_DEVICE float overload_exp2(float f) {
    return ::exp2f(f);
  }

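  // half specialization issues the approximate ex2 PTX instruction directly
  // on the raw bits instead of going through a float round-trip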
  CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) {
    auto a = f.raw();
    decltype(a) d;
    asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a));
    return cutlass::half_t::bitcast(d);
  }

  CUTLASS_DEVICE float overload_max(float a, float b) {
    return ::max(a, b);
  }

  CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) {
    return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())};
  }

  CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) {
    return f.to_half();
  }

  CUTLASS_DEVICE float overload_to_native(float f) {
    return f;
  }

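  // One softmax step over a fresh QK tile: per-thread linear max reduction,
  // then a cross-lane max via butterfly shuffles, then
  // exp2(scale*x - scale*max) and a row sum. scale_softmax_log2 folds
  // log2(e) into the softmax scale so exp2 can stand in for exp.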
  template<class AccQK, class TiledMmaQK, class CountQK, class State, class ProblemShape>
  CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) {
    Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
    Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
    auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
    constexpr int red_rank = decltype(rank(reduction_target_qk))::value;

    auto& s_max = get<0>(state);
    auto& a_sum = get<1>(state);

    // Linear reduction is faster for the first iteration
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_qk_mn); i++) {
      s_max(i) = acc_qk_mn(i, 0);
    }
    CUTLASS_PRAGMA_UNROLL
    for (int j = 1; j < size<1>(acc_qk_mn); j++) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size<0>(acc_qk_mn); i++) {
        s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
      }
    }

    for_each(make_seq<red_rank>{}, [&](auto r) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size<0>(acc_qk_mn); i++) {
          s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
        }
      }
    });
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_qk_mn); i++) {
      MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
      MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
      MaxType scale_max = scale * local_max;

      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < size<1>(acc_qk_mn); j++) {
        acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
      }
    }
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_qk_mn); i++) {
      a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
    }
  }

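  // Interleaved variant, first half: refresh the running row max, then
  // rescale both the running sum and the PV accumulator by
  // exp2((old_max - new_max) * scale) so output accumulated under the old
  // max stays consistent; the exponentials themselves run later in
  // step_interleave_step, overlapped with the PV MMAs.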
  template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
  CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {

    if constexpr (kUseFusion) {
      Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
    }

    Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
    Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));

    static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
    auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
    constexpr int red_rank = decltype(rank(reduction_target_qk))::value;

    auto& s_max = get<0>(state);
    auto& a_sum = get<1>(state);

    Tensor s_max_prev = make_fragment_like(s_max);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_qk_mn); i++) {
      s_max_prev(i) = s_max(i);
    }
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_qk_mn); i++) {
      // Linear reduction is faster here, as well
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < size<1>(acc_qk_mn); j++) {
        s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
      }
    }
    // reduce max
    for_each(make_seq<red_rank>{}, [&](auto r) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size<0>(acc_qk_mn); i++) {
          s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride<r>(reduction_target_qk) * j));
        }
      }
    });
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_pv_mn); i++) {
      float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i);
      float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2);
      a_sum(i) *= scale;
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < size<1>(acc_pv_mn); j++) {
        acc_pv_mn(i, j) *= scale;
      }
    }
  }

  template<class AccQK_MN, class State>
  CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) {

    auto& s_max = get<0>(state);
    auto& a_sum = get<1>(state);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < size<0>(acc_qk_mn); j++) {
      float local_max = s_max(j) == -INFINITY ? 0.f : s_max(j);
      float scale_max = params.scale_softmax_log2 * local_max;

      CUTLASS_PRAGMA_UNROLL
      for (int k = 0; k < size<1>(acc_qk_mn); k++) {
        acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max);
        a_sum(j) += acc_qk_mn(j, k);
      }
    }
  }

  template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
  CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {

    if constexpr (kUseFusion) {
      Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
    }

    Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
    Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));

    static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
    auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
    constexpr int red_rank = decltype(rank(reduction_target_qk))::value;

    auto& s_max = get<0>(state);
    auto& a_sum = get<1>(state);

    Tensor s_max_prev = make_fragment_like(s_max);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_qk_mn); i++) {
      s_max_prev(i) = s_max(i);

      // Linear reduction is faster here, as well
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < size<1>(acc_qk_mn); j++) {
        s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
      }
      // reduce max
      for_each(make_seq<red_rank>{}, [&](auto r) {
        CUTLASS_PRAGMA_UNROLL
        for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
          s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
        }
      });

      MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
      MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
      MaxType scale_max = scale * local_max;

      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < size<1>(acc_qk_mn); j++) {
        acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
      }

      MaxType s_max_cur = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
      SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale);
      a_sum(i) *= scale_pv;

      using ElementPV = typename AccPV::value_type;
      ElementPV scale_pv_ele = static_cast<ElementPV>(scale_pv);
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < size<1>(acc_pv_mn); j++) {
        acc_pv_mn(i, j) *= scale_pv_ele;
      }
      a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
    }
  }

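  // Epilogue: finish the cross-lane sum reduction, normalize the PV
  // accumulator by 1/sum (empty or NaN rows get scale 1), and return the
  // per-row log-sum-exp, lse = max * scale_softmax + log(sum).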
  template<class State, class AccPV, class TiledMmaPV>
  CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) {
    auto& s_max = get<0>(state);
    auto& a_sum = get<1>(state);

    Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));

    auto reduction_target = reduction_target_n(tiled_mma_pv);
    constexpr int red_rank = decltype(rank(reduction_target))::value;
    for_each(make_seq<red_rank>{}, [&](auto r) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 1; j < shape<r>(reduction_target); j *= 2) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size<0>(acc_pv_mn); i++) {
          a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride<r>(reduction_target) * j);
        }
      }
    });

    Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));

    Tensor lse = make_fragment_like(a_sum);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<0>(acc_mn); i++) {
      float sum = a_sum(i);
      float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum);
      lse(i) = (sum == 0.f || sum != sum) ? INFINITY : s_max(i) * params.scale_softmax + __logf(sum);
      float scale = params.rp_dropout * inv_sum;
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < size<1>(acc_mn); j++) {
        acc_mn(i, j) *= scale;
      }
    }

    return lse;
  }
};

} // namespace cutlass::fmha::collective
526  examples/88_hopper_fmha/collective/fmha_collective_tma.hpp  Normal file
@ -0,0 +1,526 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::collective {

using namespace cute;
using cutlass::fmha::kernel::Tag;
using cutlass::fmha::kernel::find_option_t;

template<
  typename Element_,
  typename ElementAccumulator_,
  typename TileShape_,  // BlockQO, BlockKV, BlockHead
  class Fusion,
  class... Options
>
struct FmhaMainloopTma {

  using Element = Element_;
  using ElementAccumulator = ElementAccumulator_;
  using TileShape = TileShape_;

  // Options
  using kClusterM = find_option_t<Tag::kClusterM, Int<1>, Options...>;
  static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<4>, Options...>::value;
  static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<1>, Options...>::value;

  using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
  using Stages = cutlass::gemm::collective::StageCount<StageCount>;
  using ClusterShape = Shape<kClusterM, _1, _1>;

  // 16B alignment lets us use TMA
  static constexpr int Alignment = 16 / sizeof(Element);

  using TileShapeQK = TileShape;
  using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));

  using LayoutQKV = cute::tuple<int, _1, cute::tuple<int, int>>;
  using LayoutQ = LayoutQKV;
  using LayoutK = LayoutQKV;
  using LayoutV = LayoutQKV;

  using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Element, LayoutQ, Alignment,
      Element, LayoutK, Alignment,
      ElementAccumulator,
      TileShapeQK, ClusterShape, Stages,
      cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

  using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      // the stride for A does not matter since we do not load from smem at all
      Element, LayoutK, Alignment,
      Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
      ElementAccumulator,
      TileShapePV, ClusterShape, Stages,
      cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

  using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
  using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));

  using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
  using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
  using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;

  using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
  using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;

  using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
  using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;

  using TileShapeOut = TileShapePV;
  using TiledMmaOut = TiledMmaPV;
  using ElementOut = ElementAccumulator;

  struct SharedStorage {
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
    union {
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
    };
  };

  struct Arguments {
    const Element* ptr_Q;
    LayoutQ dQ;
    const Element* ptr_K;
    LayoutK dK;
    const Element* ptr_V;
    LayoutV dV;
  };

  using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
  using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
  using TMA_V = typename CollectiveMmaPV::Params::TMA_B;

  struct Params {
    TMA_Q tma_load_q;
    TMA_K tma_load_k;
    TMA_V tma_load_v;

    float scale_softmax;
    float scale_softmax_log2;
    float rp_dropout;
  };

  using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
    cutlass::fmha::collective::LoadKind::kQ,
    MainloopPipelineQ,
    Element,
    SmemLayoutQ,
    TMA_Q
  >;

  using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
    cutlass::fmha::collective::LoadKind::kK,
    MainloopPipeline,
    Element,
    SmemLayoutK,
    TMA_K
  >;

  using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
    cutlass::fmha::collective::LoadKind::kV,
    MainloopPipeline,
    Element,
    SmemLayoutV,
    TMA_V
  >;

  static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));

  static const int MaxThreadsPerBlock = size(typename CollectiveMmaQK::TiledMma{});

  template<class ProblemShape>
  static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
    return true
      && (get<4>(problem_size) <= get<2>(TileShape{}))
      && ((get<4>(problem_size) % Alignment) == 0)
      && ((get<2>(problem_size) % Alignment) == 0)
      ;
  }

  template<class ProblemShape>
  static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {

    auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
    auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
        typename CollectiveMmaQK::Arguments {
          args.ptr_Q, args.dQ,
          args.ptr_K, args.dK,
        }, /*workspace=*/ nullptr);

    auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
    auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
        typename CollectiveMmaPV::Arguments {
          args.ptr_K, args.dK, // never used, dummy
          args.ptr_V, select<1,0,2>(args.dV),
        }, /*workspace=*/ nullptr);

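    // softmax scale is 1/sqrt(head_dim); the log2 variant additionally
    // folds in log2(e) so the kernel can use exp2 instead of exp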
    return Params{
      params_qk.tma_load_a,
      params_qk.tma_load_b,
      params_pv.tma_load_b,
      1.0f / (float) std::sqrt(get<4>(problem_size)),
      (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
      1.0f
    };
  }

  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& params) {
    cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
  }

  template<class BlkCoord, class ProblemShape>
  CUTLASS_DEVICE auto
  compute(
      int block_rank_in_cluster,
      BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
      MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write,
      MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q,
      SharedStorage& storage)
  {
    int warp_idx = cutlass::canonical_warp_idx_sync();
    int thread_idx = threadIdx.x;

    PipelineState smem_pipe_release = smem_pipe_read;
    [[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;

    int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);

    LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
    auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1);

    LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
    auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);

    LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
    auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);

    // Set predicate for the lowest lane_id in the warp
    int lane_predicate = cute::elect_one_sync();

    // Issue TmaLoads (Prologue fetches)
    if (warp_idx == 0) {
      auto q_tile_iter = cute::make_coord_iterator(1);
      int q_tile_count = 1;
      load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
    }

    // Loop over K elems
    auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);

    int k_tile_count_tma = 2 * fusion_tile_count;

    uint16_t mcast_mask_b = 0;

    if (warp_idx == 0 && lane_predicate == 1) {
      if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
        for (int m = 0; m < size<0>(block_layout); ++m) {
          mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
        }
      }

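      // K and V alias the same shared memory (see the union in
      // SharedStorage), so the prologue alternates K and V loads into
      // even/odd pipeline stages.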
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < StageCount; i++) {
        if (i % 2 == 0) {
          load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
        } else {
          // V must consume its own load state (the mainloop below does the
          // same); passing K's state here would fetch K data into V's stage
          load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
        }
      }
    }

    TiledMmaQK tiled_mma_qk;
    auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);

    // Mainloop setup QK
    Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});

    Tensor tSsQ = thr_mma_qk.partition_A(sQ);        // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tSsK = thr_mma_qk.partition_B(sK);        // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ);  // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK);  // (MMA,MMA_M,MMA_N,PIPE)

    // Prepare: MMA PV
    TiledMmaPV tiled_mma_pv;
    auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);

    // Mainloop setup PV
    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});

    Tensor tOsV = thr_mma_pv.partition_B(sV);        // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV);  // (MMA,MMA_M,MMA_N,PIPE)

    int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);

    pipeline_q.consumer_wait(smem_pipe_read_q);

    // mapping into QK accumulator
    Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
    Tensor tPcP = thr_mma_qk.partition_C(cP);
    int m_block = get<0>(blk_coord);
    tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});

    // Allocate PV acc
    Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));

    cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulator, Fusion, decltype(params)> softmax{params};
    auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);

    if (true)
    {
      --k_tile_count;
      // Allocate QK acc
      Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));

      pipeline.consumer_wait(smem_pipe_read);

      // MMA QK
      warpgroup_fence_operand(acc_qk);
      warpgroup_arrive();

      gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
      warpgroup_commit_batch();

      ++smem_pipe_read;

      // Wait for the pipeline MMAs to drain
      warpgroup_wait<0>();
      warpgroup_fence_operand(acc_qk);

      softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);

      Tensor acc_qk_fixed = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})));

      Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout());
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(acc_qk); i++) {
        acc_qk_input(i) = static_cast<Element>(acc_qk(i));
      }

      pipeline.consumer_wait(smem_pipe_read);

      // MMA PV
      warpgroup_fence_operand(acc_pv);
      warpgroup_fence_operand(acc_qk_fixed);
      warpgroup_arrive();

      gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
      warpgroup_commit_batch();

      //
      // Advance the pipe
      //

      // Advance consumer pipeline
      ++smem_pipe_read;

      pipeline.consumer_release(smem_pipe_release);
      ++smem_pipe_release;

      tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
    }

    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count)
    {
      // Allocate QK acc
      Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));

      pipeline.consumer_wait(smem_pipe_read);

      // MMA QK
      warpgroup_fence_operand(acc_qk);
      warpgroup_arrive();

      gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
      warpgroup_commit_batch();

      ++smem_pipe_read;

      if (warp_idx == 0) {
        load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
      }

      // Wait for the pipeline MMAs to drain
      warpgroup_wait<0>();
      warpgroup_fence_operand(acc_qk);
      warpgroup_fence_operand(acc_pv);

      softmax.template step_interleave_begin<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);

      pipeline.consumer_release(smem_pipe_release);

      ++smem_pipe_release;

      pipeline.consumer_wait(smem_pipe_read);

      // MMA PV
      auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));

      Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);

      static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
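      // Interleave softmax with PV: scale and exponentiate one k-block of
      // the QK accumulator at a time, convert it to the MMA input type, and
      // immediately feed it as the register-resident A operand of the PV
      // GEMM.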
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size<2>(tOrV); i++) {
        Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
        Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
        Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
        softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < size(acc_qk_element_mk); j++) {
          acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
        }
        warpgroup_arrive();
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < size<1>(tOrV); j++) {
          cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
        }
      }
      warpgroup_commit_batch();

      // Wait for the pipeline MMAs to drain
      pipeline.consumer_release(smem_pipe_release);
      ++smem_pipe_release;

      ++smem_pipe_read;

      if (warp_idx == 0) {
        load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
      }

      tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
    }

    k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);

    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count)
    {
      // Allocate QK acc
      Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));

      pipeline.consumer_wait(smem_pipe_read);

      // MMA QK
      warpgroup_fence_operand(acc_qk);
      warpgroup_arrive();

      gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
      warpgroup_commit_batch();

      ++smem_pipe_read;

      if (warp_idx == 0) {
        load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
      }

      // Wait for the pipeline MMAs to drain
      warpgroup_wait<0>();
      warpgroup_fence_operand(acc_qk);
      warpgroup_fence_operand(acc_pv);

      softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);

      pipeline.consumer_release(smem_pipe_release);

      ++smem_pipe_release;

      pipeline.consumer_wait(smem_pipe_read);

      // MMA PV
      auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));

      Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);

      static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size<2>(tOrV); i++) {
        Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
        Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
        Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
        softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < size(acc_qk_element_mk); j++) {
          acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
        }
        warpgroup_arrive();
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < size<1>(tOrV); j++) {
          cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
        }
      }
      warpgroup_commit_batch();

      // Wait for the pipeline MMAs to drain
      pipeline.consumer_release(smem_pipe_release);
      ++smem_pipe_release;

      ++smem_pipe_read;

      if (warp_idx == 0) {
        load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
      }

      tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
    }

    // Wait for the pipeline MMAs to drain
    warpgroup_wait<0>();
    warpgroup_fence_operand(acc_pv);

    Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);

    return make_tuple(acc_pv, lse);
  }
};

} // namespace cutlass::fmha::collective

@ -0,0 +1,560 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::collective {

using namespace cute;
using cutlass::fmha::kernel::Tag;
using cutlass::fmha::kernel::find_option_t;

template<
  class Element_,
  class ElementAccumulatorQK_,
  class ElementAccumulatorPV_,
  class TileShape_,  // SeqQ, SeqKV, Head
  class LayoutQ_, class LayoutK_, class LayoutV_,  // SeqX, Head, (Batches)
  class Fusion,
  class... Options
>
struct FmhaMainloopTmaWarpSpecialized {

  using Element = Element_;
  using ElementAccumulatorQK = ElementAccumulatorQK_;
  using ElementAccumulatorPV = ElementAccumulatorPV_;
  using TileShape = TileShape_;

  using LayoutQ = LayoutQ_;
  using LayoutK = LayoutK_;
  using LayoutV = LayoutV_;

  // Options
  static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
  static constexpr bool kIsMainloopLocked = find_option_t<Tag::kIsMainloopLocked, false_type, Options...>::value;

  static constexpr int NumLoadWarpGroups = 1;
  static constexpr int NumMmaWarpGroups = find_option_t<Tag::kNumMmaWarpGroups, Int<2>, Options...>::value;
  static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<5>, Options...>::value;
  static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<NumMmaWarpGroups>, Options...>::value;

  static const int kOuterLoads = 1;
  using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
  using Stages = cutlass::gemm::collective::StageCount<StageCount>;
  using ClusterShape = Shape<_1, _1, _1>;
  static_assert(StagesQ::value >= NumMmaWarpGroups);
  static_assert(Stages::value >= 2);

  // 16B alignment lets us use TMA
  static constexpr int Alignment = 16 / sizeof(Element);

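  // Split the Q extent of the tile evenly across the MMA warp groups;
  // the K/V extents are shared by all of them.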
  using TileShapeQK = Shape<
    decltype(tuple_element_t<0, TileShape>{} / Int<NumMmaWarpGroups>{}),
    tuple_element_t<1, TileShape>,
    tuple_element_t<2, TileShape>>;

  using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));

  using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Element, LayoutQ, Alignment,
      Element, LayoutK, Alignment,
      ElementAccumulatorQK,
      TileShapeQK, ClusterShape, Stages,
      cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

  using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      // the stride for A does not matter since we do not load from smem at all
      Element, LayoutK, Alignment,
      Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
      ElementAccumulatorPV,
      TileShapePV, ClusterShape, Stages,
      cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;

  using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
  using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));

  using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
  using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
  using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;

  using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
  using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;

  using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
  using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;

  static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
  static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element);

  using TileShapeOut = TileShapePV;
  using TiledMmaOut = TiledMmaPV;
  using ElementOut = ElementAccumulatorPV;

  struct SharedStorage {
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
    union {
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
    };
  };

  struct Arguments {
    const Element* ptr_Q;
    LayoutQ dQ;
    const Element* ptr_K;
    LayoutK dK;
    const Element* ptr_V;
    LayoutV dV;
  };

  using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
  using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
  using TMA_V = typename CollectiveMmaPV::Params::TMA_B;

  struct Params {
    TMA_Q tma_load_q;
    TMA_K tma_load_k;
    TMA_V tma_load_v;

    float scale_softmax;
    float scale_softmax_log2;
    float rp_dropout;
  };

  using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
    cutlass::fmha::collective::LoadKind::kQ,
    MainloopPipelineQ,
    Element,
    SmemLayoutQ,
    TMA_Q
  >;

  using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
    cutlass::fmha::collective::LoadKind::kK,
    MainloopPipeline,
    Element,
    SmemLayoutK,
    TMA_K
  >;

  using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
    cutlass::fmha::collective::LoadKind::kV,
    MainloopPipeline,
    Element,
    SmemLayoutV,
    TMA_V
  >;

  static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));

  template<class ProblemShape>
  static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
    return true
      && (get<4>(problem_size) <= get<2>(TileShape{}))
      && ((get<4>(problem_size) % Alignment) == 0)
      && ((get<2>(problem_size) % Alignment) == 0)
      ;
  }

  template<class ProblemShape>
  static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {

    auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
    auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
        typename CollectiveMmaQK::Arguments {
          args.ptr_Q, args.dQ,
          args.ptr_K, args.dK,
        }, /*workspace=*/ nullptr);

    auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
    auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
        typename CollectiveMmaPV::Arguments {
          args.ptr_K, args.dK, // never used, dummy
          args.ptr_V, select<1,0,2>(args.dV),
        }, /*workspace=*/ nullptr);

    return Params{
      params_qk.tma_load_a,
      params_qk.tma_load_b,
      params_pv.tma_load_b,
      1.0f / (float) std::sqrt(get<4>(problem_size)),
      (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
      1.0f
    };
  }

  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& params) {
    cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
  }

  template<bool kLoadQ, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
  CUTLASS_DEVICE void
  load_kv_maybe_q(
      int block_rank_in_cluster,
      BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
      MainloopPipeline& pipeline, PipelineState& smem_pipe_write,
      MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
      SharedStorage& storage,
      LoadWarpBarrier& load_warp_barrier, bool do_barrier)
  {
    int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);

    int lane_predicate = cute::elect_one_sync();

    uint16_t mcast_mask_b = 0;

    if (lane_predicate == 1) {
      if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
        for (int m = 0; m < size<0>(block_layout); ++m) {
          mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
        }
      }
    }

    auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
    [[maybe_unused]] int q_tile_count = NumMmaWarpGroups;

    auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
    int k_tile_count = 2 * fusion_tile_count;

    LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
    auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);

    LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
    auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);

    LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
    auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);

    if constexpr (kLoadQ) {
      load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
    }

    load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);

    if constexpr (kLoadQ) {
      load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
    }

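    // When this warp does not own the Q load, it synchronizes once with the
    // warp that does before streaming V and the remaining K/V tiles.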
if constexpr (! kLoadQ) {
|
||||
if (do_barrier) {
|
||||
load_warp_barrier.arrive();
|
||||
load_warp_barrier.wait(/*phase=*/ 0);
|
||||
do_barrier = false;
|
||||
}
|
||||
}
|
||||
|
||||
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||
|
||||
if constexpr (kLoadQ) {
|
||||
while (q_tile_count > 0) {
|
||||
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0) {
|
||||
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||
CUTLASS_DEVICE void
|
||||
load_maybe_q(
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
|
||||
SharedStorage& storage,
|
||||
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
|
||||
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
|
||||
|
||||
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) {
|
||||
int count = 1;
|
||||
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count);
|
||||
if (q_tile_count == 0 && do_barrier) {
|
||||
load_warp_barrier.arrive();
|
||||
load_warp_barrier.wait(/*phase=*/ 0);
|
||||
do_barrier = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
|
||||
CUTLASS_DEVICE void
|
||||
reduce(
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
|
||||
SharedStorage& storage)
|
||||
{ /* no-op */ }
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
|
||||
CUTLASS_DEVICE auto
|
||||
compute(
|
||||
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
|
||||
Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipeline& pipeline, PipelineState& smem_pipe_read,
|
||||
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipelineReducer&, PipelineStateReducer&,
|
||||
SharedStorage& storage,
|
||||
MathWgOrderBarrier& math_wg_order_barrier)
|
||||
{
|
||||
int thread_idx = int(threadIdx.x);
|
||||
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
|
||||
|
||||
TiledMmaQK tiled_mma_qk;
|
||||
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
|
||||
// Mainloop setup QK
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});

Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_N,MMA_K,PIPE)

// Prepare: MMA PV
TiledMmaPV tiled_mma_pv;
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);

// Mainloop setup PV
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_N,MMA_K,PIPE)

int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);

pipeline_q.consumer_wait(smem_pipe_read_q);

// mapping into QK accumulator
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
Tensor tPcP = thr_mma_qk.partition_C(cP);
int m_block = get<0>(wg_coord);
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});

// Allocate PV acc
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));

cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulatorQK, Fusion, decltype(params)> softmax{params};
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);

if (true)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));

pipeline.consumer_wait(smem_pipe_read);
math_wg_order_barrier.wait();

// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();

gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
math_wg_order_barrier.arrive();

++smem_pipe_read;

// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);

softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);

Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});

pipeline.consumer_wait(smem_pipe_read);

// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();

gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();

pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;

// Advance consumer pipeline
++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
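// Note (editorial): in the steady-state loop below, each iteration issues the
// QK GEMM for the current tile, drains all outstanding GMMAs (including the
// previous tile's PV), runs softmax with accumulator rescaling, and then
// issues PV for the current tile -- a two-stage software pipeline.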

CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0)
{
--k_tile_count;

// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));

pipeline.consumer_wait(smem_pipe_read);

// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();

gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();

++smem_pipe_read;
auto tok = pipeline.consumer_try_wait(smem_pipe_read);

// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);

if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
softmax.template step<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();

Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});

pipeline.consumer_wait(smem_pipe_read, tok);

// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();

cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();

pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;

pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;

++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}

k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);

CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0)
{
--k_tile_count;

// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));

pipeline.consumer_wait(smem_pipe_read);

// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();

gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();

++smem_pipe_read;
auto tok = pipeline.consumer_try_wait(smem_pipe_read);

// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);

//if constexpr (kIsPersistent)
//  if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q);

if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();

Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});

pipeline.consumer_wait(smem_pipe_read, tok);

// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();

cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();

pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;

pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;

++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}

if (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q);

// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_pv);

if (kIsPersistent) pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;

Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);

return make_tuple(acc_pv, lse);
}
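// Note (editorial): compute() hands back the output accumulator (finalized by
// softmax.tail) together with the per-row log-sum-exp; FmhaFwdEpilogue below
// writes both to global memory.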
};

} // namespace cutlass::fmha::collective

245 examples/88_hopper_fmha/collective/fmha_common.hpp Normal file
@ -0,0 +1,245 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/kernel_hardware_info.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::collective {

using namespace cute;

template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
constexpr int rA = decltype(rank(tA))::value;
constexpr int rB = decltype(rank(tB))::value;
constexpr int rC = decltype(rank(tC))::value;
if constexpr (rA == 2 && rB == 2 && rC == 1) {
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<1>(tA); k_block++) {
cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
} else {
static_assert(rA == 3 && rB == 3 && rC == 3);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
}
}

template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
atom.accumulate_ = GMMA::ScaleOut::Zero;
gemm_reset_zero_acc(atom, tA, tB, tC);
}
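// Illustrative usage (editorial): gemm_zero_acc overwrites the accumulator on
// the first k-block and accumulates thereafter, so no separate clear is needed:
//   Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0,2>(TileShapeQK{}));
//   gemm_zero_acc(tiled_mma_qk, tSrQ, tSrK, acc_qk);  // acc_qk = Q * K^T
// (tensor names as in the forward mainloop above)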

template<typename T, typename Fn>
CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) {
if constexpr (decltype(size(t) % _2{} == _0{})::value) {
auto partial = make_tensor<typename T::value_type>(size(t) / _2{});
CUTE_UNROLL
for (int i = 0; i < size(partial); i++) {
partial(i) = fn(t(i), t(i + size(partial)));
}
return reduce(partial, fn);
} else {
auto result = t(_0{});
CUTE_UNROLL
for (int i = 1; i < size(t); i++) {
result = fn(result, t(i));
}
return result;
}
}
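// Illustrative usage (editorial): a row maximum over a register fragment,
// reduced tree-style while the fragment size stays even, e.g.
//   float row_max = reduce(tSrS_row, fmha_max{});
// `tSrS_row` is a hypothetical per-row slice of a softmax accumulator.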

struct fmha_max {
CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); }
};

template<typename Threshold, typename Source, typename Reference>
inline auto __device__ constexpr layout_separate(Threshold const& thr,
Source const& src, Reference const& ref) {
auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
if constexpr(decltype(r < thr)::value) {
return s;
} else {
return make_layout(_1{}, _0{});
}
}));
auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
if constexpr(decltype(r >= thr)::value) {
return s;
} else {
return make_layout(_1{}, _0{});
}
}));
return make_layout(lt, ge);
}

template<typename TiledMma, typename Acc>
inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) {
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{}));
auto V_M = get<0>(separated);
auto V_N = get<1>(separated);
return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc)));
}

template<typename TiledMma, typename Acc>
inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) {
return layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{}));
}

template<typename TiledMma, typename Acc>
inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& acc) {
return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout()));
}

template<typename TiledMma>
inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) {
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
make_layout(shape<0>(typename TiledMma::LayoutC_TV{})),
stride<0>(typename TiledMma::LayoutC_TV{}));
return get<1>(separated);
}


template<template<cute::GMMA::Major, cute::GMMA::Major, cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<tA, tB, sA, sB>> const& tiled_mma) {
using Atom = cute::MMA_Atom<Primitive<tA, tB, sA, sB>>;
using ElementA = typename Atom::ValTypeA;
using ElementB = typename Atom::ValTypeB;
using ElementC = typename Atom::ValTypeC;
using Shape_MNK = typename Atom::Shape_MNK;
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
return cute::MMA_Atom<RS>{};
}

template<template<cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<sA, sB>> const& tiled_mma) {
using Atom = cute::MMA_Atom<Primitive<sA, sB>>;
using ElementA = typename Atom::ValTypeA;
using ElementB = typename Atom::ValTypeB;
using ElementC = typename Atom::ValTypeC;
using Shape_MNK = typename Atom::Shape_MNK;
constexpr auto tA = cute::GMMA::Major::K;
constexpr auto tB = cute::GMMA::Major::K;
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
return cute::MMA_Atom<RS>{};
}

template<class Atom, class... Args>
CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA<Atom, Args...> const& tiled_mma) {
return cute::TiledMMA<decltype(convert_to_gmma_rs(Atom{})), Args...>{};
}

template<typename CLayout, typename AValueShape>
CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) {
return make_layout(
make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))),
make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c))));
}

template<class Layout, class Stages = _1>
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, make_layout(stages)));
}
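// Illustrative usage (editorial): restrict a multi-stage smem layout to its
// first `stages` pipeline stages, e.g. keeping one stage of a staged Q layout:
//   using SmemLayoutQSingle = decltype(unstageSmemLayout(SmemLayoutQ{}, _1{}));
// `SmemLayoutQ` is assumed to be a (tile_M, tile_K, PIPE) staged layout.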

template<class Element, class Accumulator, class OperandLayout_TV>
CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) {
Tensor operand = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv)));
Tensor operand_as_acc = make_tensor(operand.data(), acc.layout());

cute::copy(acc, operand_as_acc);

if constexpr (sizeof(Element) == 1) {

// 00 11 22 33 00 11 22 33 acc layout
// 00 00 11 11 22 22 33 33 operand layout
// BB AA AA BB AA BB BB AA conflict-free exchange pattern
// 16-bit exchange; so process two at a time potentially
int tid = threadIdx.x % 4;
auto values_u32 = recast<uint32_t>(operand);

CUTE_UNROLL
for (int n = 0; n < size<1>(values_u32); n++) {
CUTE_UNROLL
for (int k = 0; k < size<2>(values_u32); k++) {
CUTE_UNROLL
for (int ii = 0; ii < 8; ii += 4) {

uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k);
uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k);

// step A:
// t 1 v 0 -> t 0 v 1
// t 2 v 0 -> t 1 v 0
// t 0 v 1 -> t 2 v 0
// t 3 v 1 -> t 3 v 1

int v_to_send = tid == 1 || tid == 2 ? 0 : 1;
int v_to_recv = v_to_send;
int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF;

uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1;

values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4);

// step B:
// t 0 v 0 -> t 0 v 0
// t 3 v 0 -> t 1 v 1
// t 1 v 1 -> t 2 v 1
// t 2 v 1 -> t 3 v 0

v_to_send = 1 - v_to_send;
v_to_recv = 1 - v_to_recv;
t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF;

uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1;

values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4);

values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410);
values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632);
}
}
}
}

return operand;
}
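// Note (editorial): for 8-bit element types the accumulator-to-operand layout
// mismatch is resolved entirely in registers -- two __shfl_sync exchanges per
// eight values within a quad of threads, recombined with __byte_perm -- rather
// than taking a round trip through shared memory.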

} // namespace cutlass::fmha::collective

156 examples/88_hopper_fmha/collective/fmha_epilogue.hpp Normal file
@ -0,0 +1,156 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"

#include "../collective/fmha_common.hpp"

namespace cutlass::fmha::collective {

template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaFwdEpilogue {

static constexpr int Alignment = 16 / sizeof(Element);

using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
cutlass::epilogue::TmaWarpSpecialized,
DefaultOperation
>::CollectiveOp;

struct Arguments {
Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> dO;

ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
};

struct Params {
ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;

typename CollectiveEpilogueTMA::Params epilogue_TMA;
};

using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage;
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;

template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1,
make_shape(get<0>(problem_size), get<1>(problem_size)));
typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO};
return Params{
args.ptr_LSE, args.dLSE,
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace)
};
}
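// Note (editorial): problem_size is ordered (batch, heads, seqlen_q, seqlen_k,
// head_dim) -- see the get<0>/get<1>/get<2> reads in operator() below -- so O
// is presented to the GEMM epilogue as a (seqlen_q, head_dim) output with
// (batch, heads) folded into the batch mode.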

template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
CUTLASS_DEVICE void operator()(
TileShape const& tile_shape, BlkCoord const& blk_coord,
ResultTuple const& result, TiledMma const& tiled_mma,
ProblemShape const& problem_size, Params const& params,
LoadPipeline epi_load_pipeline,
TensorStorage& epi_tensor_storage)
{
using X = Underscore;

auto acc = get<0>(result);
auto lse = get<1>(result);

auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);

int seqlen_q = get<2>(problem_size);
int num_batch = get<0>(problem_size);
int num_heads = get<1>(problem_size);
// Epilogue for lse
Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE),
make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)),
make_stride(_1{}, _0{}, get<1>(params.dLSE)));
Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{});
Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord));
Tensor tOgLSE = thr_mma.partition_C(gLSE);
Tensor cO = make_identity_tensor(take<0,2>(tile_shape));
Tensor tOcO = thr_mma.partition_C(cO);
if (get<1>(tOcO(_0{})) == 0) {
auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout()));
auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout()));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tOgLSE_mn); i++) {
if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) {
tOgLSE_mn(i, _0{}) = lse(i);
}
}
}
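// Note (editorial): only threads holding column 0 of the accumulator write LSE
// (the get<1>(tOcO(_0{})) == 0 guard), and each write is row-predicated
// against seqlen_q so partial tiles do not write out of bounds.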
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _,
make_shape(get<0>(problem_size), get<1>(problem_size)));

CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage);

using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);

typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();

auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_tma.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage
);

epilogue_tma.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
}
};

} // namespace cutlass::fmha::collective

157 examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp Normal file
@ -0,0 +1,157 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "../collective/fmha_epilogue.hpp"

namespace cutlass::fmha::collective {

template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaBwdEpilogueKV {

static constexpr int Alignment = 16 / sizeof(Element);

struct Arguments {
Element* ptr_K;
cute::tuple<int, int, int, cute::_1> dK;

Element* ptr_V;
cute::tuple<int, int, int, _1> dV;
};

//using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90Compute<cutlass::first, Element, ElementAccumulator, RoundStyle>,
cutlass::epilogue::fusion::Sm90AccFetch
>;
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
cutlass::epilogue::TmaWarpSpecialized,
DefaultOperation
>::CollectiveOp;

struct Params {
typename CollectiveEpilogueTMA::Params epilogue_K;
typename CollectiveEpilogueTMA::Params epilogue_V;
};


using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2];
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;

template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
auto dK = make_stride(get<2>(args.dK), get<3>(args.dK),
make_stride(get<0>(args.dK), get<1>(args.dK)));
auto dV = make_stride(get<2>(args.dV), get<3>(args.dV),
make_stride(get<0>(args.dV), get<1>(args.dV)));

auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1,
make_shape(get<0>(problem_size), get<1>(problem_size)));
typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK};
typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV};
return Params{
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr),
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr)
};
}
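// Note (editorial): args.dK/args.dV arrive in (batch, head, seqlen_k, d)
// stride order and are reordered above into the epilogue's
// (seqlen_k, d, (batch, head)) convention; ptr_K/ptr_V here name the K- and
// V-shaped outputs this epilogue writes (the dK/dV gradients in FmhaBwd below).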

template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
CUTLASS_DEVICE void operator()(
TileShape const& tile_shape, BlkCoord const& blk_coord,
ResultTuple const& result, TiledMma const& tiled_mma,
ProblemShape const& problem_size, Params const& params,
LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage)
{
auto acc_k = get<0>(result);
auto acc_v = get<1>(result);

auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _,
make_shape(get<0>(problem_size), get<1>(problem_size)));

using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);

typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();

CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]};
CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]};

{
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_k.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage[0]
);

}

{
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_v.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage[1]
);

epilogue_k.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);

epilogue_v.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
}
}
};

} // namespace cutlass::fmha::collective

283 examples/88_hopper_fmha/collective/fmha_fusion.hpp Normal file
@ -0,0 +1,283 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::collective {

using namespace cute;

struct DefaultFusion {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return ceil_div(get<3>(problem_size), get<1>(tile_shape));
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size);
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return 0;
}

template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size

) {
return;
}
};
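// Note (editorial): the fusion contract is
//   get_trip_count == get_unmasked_trip_count + get_masked_trip_count;
// the forward mainloop runs the unmasked iterations first and applies
// before_softmax only during the masked tail (compare the step<false> vs.
// step calls in compute() above).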

struct ResidualFusion : DefaultFusion {

using Base = DefaultFusion;

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return 1;
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}

template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
// This is useful if seqlen_k % kBlockN != 0, since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM != 0 do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<1>(pos) >= get<3>(problem_size)) {
acc_qk(i) = -INFINITY;
}
}
}
};

struct CausalFusion : DefaultFusion {

using Base = DefaultFusion;

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
}

template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
//    - this is what we demonstrate here
// (2) is that it is at the end of the matrix
//    - this is usually what we want for inference settings
//      where we only compute the next row and use cache for the rest
//    - if you'd like this, you only need to add an offset like so:
//      get<0>(pos) + offset_q < get<1>(pos)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<0>(pos) < get<1>(pos)) {
acc_qk(i) = -INFINITY;
}
}
}

};
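// Illustrative sketch (editorial) of variant (2) from the note above, with the
// query block aligned to the end of the key sequence; `offset_q` is
// hypothetical and not part of this commit:
//   int offset_q = get<3>(problem_size) - get<2>(problem_size);  // seqlen_k - seqlen_q
//   if (get<0>(pos) + offset_q < get<1>(pos)) { acc_qk(i) = -INFINITY; }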

template<class Base>
struct FusionBwdAdapter {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size));
}

template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
auto index_base = index_qk(_0{});
auto index_shape = shape(index_qk);
auto index_stride = transform_leaf(stride(index_qk), [](auto elem) {
if constexpr (is_scaled_basis<decltype(elem)>::value) {
if constexpr(decltype(elem.mode() == _0{})::value) {
return ScaledBasis<decltype(elem.value()), 1>(elem.value());
} else {
return ScaledBasis<decltype(elem.value()), 0>(elem.value());
}
} else {
return elem;
}
});
auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride));
Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size);
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
bool is_contributing(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return true;
}
};

template<>
struct FusionBwdAdapter<CausalFusion> {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get<2>(problem_size) / get<0>(TileShape{});
}

template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size

) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<1>(pos) < get<0>(pos)) {
acc_qk(i) = -INFINITY;
}
}
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
bool is_contributing(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape);
int min_k = get<1>(blk_coord) * get<1>(tile_shape);
return min_k <= max_q;
}
};
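// Note (editorial): is_contributing lets the backward tile scheduler skip
// tiles that lie entirely in the masked region above the causal diagonal;
// min_k <= max_q is exactly the condition for a tile to contain at least one
// unmasked element.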

} // namespace cutlass::fmha::collective

278 examples/88_hopper_fmha/device/device_universal.hpp Normal file
@ -0,0 +1,278 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*!
\file
\brief A universal device layer for CUTLASS 3.x-style kernels.
*/

#pragma once

// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"

#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::device {

////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template <class Kernel_>
class Universal {
public:
using Kernel = Kernel_;

static int const kThreadCount = Kernel::MaxThreadsPerBlock;

/// Argument structure: User API
using Arguments = typename Kernel::Arguments;
/// Argument structure: Kernel API
using Params = typename Kernel::Params;

private:

/// Kernel API parameters object
Params params_;

bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}

public:

/// Access the Params structure
Params const& params() const {
return params_;
}

/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (Kernel::can_implement(args)) {
return Status::kSuccess;
}
else {
return Status::kInvalid;
}
}

/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
return workspace_bytes;
}

/// Computes the grid shape
static dim3
get_grid_shape(Params const& params) {
return Kernel::get_grid_shape(params);
}

/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;

// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
"  cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}

// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);

if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
"  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}

CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}

/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));

// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}

// Initialize the Params structure
params_ = Kernel::to_underlying_arguments(args, workspace);

if (is_initialized()) return Status::kSuccess;

// account for dynamic smem capacity if needed
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}

is_initialized(true);

return Status::kSuccess;
}

/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace);

size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}

params_ = Kernel::to_underlying_arguments(args, workspace);
return Status::kSuccess;
}

/// Primary run() entry point API that is static, allowing users to create and manage their own params.
/// The supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = get_grid_shape(params);

// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;

Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
cutlass::arch::synclog_setup();
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
}

cudaError_t result = cudaGetLastError();
if (cudaSuccess == result && Status::kSuccess == launch_result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}

//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//

/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}

/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}

/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}

/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
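// Typical host-side usage (editorial sketch; `FwdKernel`, `args`, `workspace`
// and `stream` are placeholders):
//   cutlass::device::Universal<FwdKernel> op;
//   if (op.can_implement(args) != cutlass::Status::kSuccess) return;
//   size_t bytes = op.get_workspace_size(args);    // allocate `workspace`
//   auto status = op.run(args, workspace, stream); // initialize + launch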

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::device

////////////////////////////////////////////////////////////////////////////////

299 examples/88_hopper_fmha/device/fmha_device_bwd.hpp Normal file
@ -0,0 +1,299 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

/*!
  \file
  \brief A universal device layer for cutlass 3.x-style kernels.
*/

// common
#include "cutlass/cutlass.h"

#include "../device/device_universal.hpp"
#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/fmha_epilogue_bwd.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::fmha::device {

////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template<class Element, class ElementAccumulator, class TileShape, class Fusion, class... Options>
class FmhaBwd {
public:
  /// Argument structure: User API
  struct Arguments {
    cute::tuple<int, int, int, int, int> problem_size;

    const Element* ptr_Q;
    cute::tuple<int, int, int, cute::_1> stride_Q;
    const Element* ptr_K;
    cute::tuple<int, int, int, cute::_1> stride_K;
    const Element* ptr_V;
    cute::tuple<int, int, int, cute::_1> stride_V;

    const Element* ptr_O;
    cute::tuple<int, int, int, cute::_1> stride_O;
    const ElementAccumulator* ptr_LSE;
    cute::tuple<int, int, cute::_1> stride_LSE;

    const Element* ptr_dO;
    cute::tuple<int, int, int, cute::_1> stride_dO;

    Element* ptr_dQ;
    cute::tuple<int, int, int, cute::_1> stride_dQ;
    Element* ptr_dK;
    cute::tuple<int, int, int, cute::_1> stride_dK;
    Element* ptr_dV;
    cute::tuple<int, int, int, cute::_1> stride_dV;

    cutlass::KernelHardwareInfo hw_info;
  };

  using OperationSumOdO = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>>;
  using OperationConvert = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>>;

  using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized<
      Element, ElementAccumulator, TileShape,
      cutlass::fmha::collective::FusionBwdAdapter<Fusion>, Options...>;

  using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV<Element, ElementAccumulator, typename Mainloop::TileShapePV>;

  using Operation = cutlass::device::Universal<
      cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<
          Mainloop,
          Epilogue,
          cutlass::fmha::kernel::TileSchedulerBwdAdapter<cutlass::fmha::kernel::IndividualTileScheduler>, Options...>>;

  struct Params {
    OperationSumOdO op_sum_OdO;
    Operation op;
    OperationConvert op_convert;
    ElementAccumulator* dQ_acc;
    size_t dQ_acc_size;
  };

private:
  Params params_;

  static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) {
    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8);  // Alignment
    Q = cutlass::round_up(Q, 8);  // Alignment
    auto stride_sum_OdO = make_stride(H*Q, Q, _1{});
    return typename OperationSumOdO::Arguments {
      args.problem_size,
      args.ptr_O, args.stride_O,
      args.ptr_dO, args.stride_dO,
      dest, stride_sum_OdO
    };
  }

  static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8);  // Alignment
    Q = cutlass::round_up(Q, 8);  // Alignment
    auto stride_src_dQ = make_stride(B == 1 ? 0 : (H*Q*D), Q*D, D, _1{});
    return typename OperationConvert::Arguments {
      args.problem_size,
      src, stride_src_dQ,
      nullptr, stride_src_dQ,
      nullptr, stride_src_dQ,
      args.ptr_dQ, args.stride_dQ,
      nullptr, args.stride_dK,
      nullptr, args.stride_dV
    };
  }

  static typename Operation::Arguments to_bwd_arguments(
      Arguments const& args,
      ElementAccumulator* sum_OdO = nullptr, cute::tuple<int, int, _1> const& stride_sum_OdO = {},
      ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, int, int, _1> const& stride_dQ = {}
  ) {
    return typename Operation::Arguments{
      args.problem_size,
      { args.ptr_Q, args.stride_Q,
        args.ptr_K, args.stride_K,
        args.ptr_V, args.stride_V,
        args.ptr_dO, args.stride_dO,
        args.ptr_LSE, args.stride_LSE,
        sum_OdO, stride_sum_OdO,
        dQ_acc, stride_dQ },
      { args.ptr_dK, args.stride_dK,
        args.ptr_dV, args.stride_dV },
      args.hw_info
    };
  }

public:

  /// Determines whether the GEMM can execute the given problem.
  static Status
  can_implement(Arguments const& args) {
    Status status = Status::kSuccess;

    status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
    if (status != Status::kSuccess) {
      return status;
    }

    status = OperationConvert::can_implement(to_convert_arguments(args));
    if (status != Status::kSuccess) {
      return status;
    }

    status = Operation::can_implement(to_bwd_arguments(args));
    if (status != Status::kSuccess) {
      return status;
    }

    return status;
  }

  /// Gets the workspace size
  static size_t
  get_workspace_size(Arguments const& args) {
    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8);  // Alignment
    Q = cutlass::round_up(Q, 8);  // Alignment
    size_t workspace_bytes = 0;
    // OdO vector
    workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
    // FP32 versions of outputs that are churned (start off with Q only)
    workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
    return workspace_bytes;
  }

  /// Initializes state from arguments.
  Status
  initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("FmhaBwd::initialize_split() - workspace_dQ="
      << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << ", stream: " << (stream ? "non-null" : "null"));

    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8);  // Alignment
    Q = cutlass::round_up(Q, 8);  // Alignment
    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
    params_.dQ_acc = dQ_acc;
    params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
    auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO);
    auto args_convert = to_convert_arguments(args, dQ_acc);
    params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
    params_.op_convert.initialize(args_convert, nullptr, stream);
    auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ);
    params_.op.initialize(args_bwd, nullptr, stream);

    return Status::kSuccess;
  }

  /// Initializes state from arguments.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("FmhaBwd::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8);  // Alignment
    Q = cutlass::round_up(Q, 8);  // Alignment
    char* workspace_chr = reinterpret_cast<char*>(workspace);
    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
    workspace_chr += B*H*Q * sizeof(ElementAccumulator);
    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
    return initialize_split(args, dQ_acc, sum_OdO, stream);
  }

  /// Primary run() entry point API that is static, allowing users to create and manage their own params.
  /// The supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");

    Status result = Status::kSuccess;
    result = params.op_sum_OdO.run(stream);
    if (result != Status::kSuccess) {
      return result;
    }

    auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
    if (cuda_result != cudaSuccess) {
      return Status::kErrorInternal;
    }
    result = params.op.run(stream);
    if (result != Status::kSuccess) {
      return result;
    }

    result = params.op_convert.run(stream);
    if (result != Status::kSuccess) {
      return result;
    }

    return Status::kSuccess;
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////
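
For orientation, a minimal host-side driver for the FmhaBwd handle above could look as follows (a sketch only: the helper name run_bwd and its error handling are illustrative assumptions, not part of this commit):

#include <cuda_runtime.h>

// Sketch: allocate the workspace FmhaBwd reports (the sum_OdO vector plus the
// FP32 dQ accumulator), then let run() perform the whole sequence:
// sum_OdO kernel -> cudaMemsetAsync of dQ_acc -> bwd mainloop -> convert kernel.
template<class Fmha>
cutlass::Status run_bwd(typename Fmha::Arguments const& args, cudaStream_t stream = nullptr) {
  if (Fmha::can_implement(args) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorNotSupported;
  }
  void* workspace = nullptr;
  size_t workspace_bytes = Fmha::get_workspace_size(args);
  if (workspace_bytes > 0 && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }
  Fmha op;
  cutlass::Status status = op.run(args, workspace, stream);
  cudaFree(workspace);
  return status;
}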

examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp (new file)
@ -0,0 +1,158 @@
/* BSD-3-Clause copyright header (2024 - 2025 NVIDIA CORPORATION & AFFILIATES), identical to the one reproduced in full at the top of fmha_device_bwd.hpp above. */

#pragma once

#include "../collective/fmha_collective_tma.hpp"
#include "../collective/fmha_collective_tma_warpspecialized.hpp"
#include "../collective/fmha_epilogue.hpp"
#include "../kernel/fmha_kernel_tma.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::kernel {

template<
  class Element_,
  class ElementAccumulatorQK_,
  class ElementAccumulatorPV_,
  class TileShape_,  // BlockQO, BlockKV, BlockHead
  class LayoutQ_,
  class LayoutK_,
  class LayoutV_,
  class Fusion,
  class DispatchPolicy,
  class... Options
>
struct FmhaBuilder;

template<
  class Element,
  class ElementAccumulator,
  class TileShape,  // BlockQO, BlockKV, BlockHead
  class Fusion,
  class... Options
>
struct FmhaBuilder<
    Element,
    ElementAccumulator,
    ElementAccumulator,
    TileShape,
    cute::tuple<int, _1, cute::tuple<int, int>>,
    cute::tuple<int, _1, cute::tuple<int, int>>,
    cute::tuple<int, _1, cute::tuple<int, int>>,
    Fusion,
    cutlass::gemm::KernelTma,
    Options...
> {

  using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma<Element, ElementAccumulator, TileShape, Fusion, Options...>;

  using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
      Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>;

  using Kernel = cutlass::fmha::kernel::FmhaKernelTma<CollectiveMainloop, CollectiveEpilogue, Options...>;
};

template<
  class Element,
  class ElementAccumulatorQK,
  class ElementAccumulatorPV,
  class TileShape,  // BlockQO, BlockKV, BlockHead
  class LayoutQ,
  class LayoutK,
  class LayoutV,
  class Fusion,
  class... Options
>
struct FmhaBuilder<
    Element,
    ElementAccumulatorQK,
    ElementAccumulatorPV,
    TileShape,
    LayoutQ,
    LayoutK,
    LayoutV,
    Fusion,
    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
    Options...
> {

  using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized<
      Element, ElementAccumulatorQK, ElementAccumulatorPV,
      TileShape, LayoutQ, LayoutK, LayoutV,
      Fusion, Options...>;

  using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
      Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>;

  static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
  using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;

  using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<CollectiveMainloop, CollectiveEpilogue, TileScheduler, Options...>;
};

template<
  class Element,
  class ElementAccumulatorQK,
  class ElementAccumulatorPV,
  class TileShape,  // BlockQO, BlockKV, BlockHead
  class LayoutQ,
  class LayoutK,
  class LayoutV,
  class Fusion,
  class... Options
>
struct FmhaBuilder<
    Element,
    ElementAccumulatorQK,
    ElementAccumulatorPV,
    TileShape,
    LayoutQ,
    LayoutK,
    LayoutV,
    Fusion,
    cutlass::gemm::KernelTmaWarpSpecializedPingpong,
    Options...
> {
  using Kernel = typename FmhaBuilder<
      Element, ElementAccumulatorQK, ElementAccumulatorPV,
      TileShape,
      LayoutQ, LayoutK, LayoutV,
      Fusion,
      cutlass::gemm::KernelTmaWarpSpecializedCooperative,
      Options...,
      Option<Tag::kIsPersistent, true_type>,
      Option<Tag::kLoadsQSeparately, true_type>
  >::Kernel;
};

} // namespace cutlass::fmha::kernel
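
Each FmhaBuilder specialization above is selected by the gemm dispatch-policy tag, and extra behavior is composed purely through the Options pack, as the Pingpong case shows by forwarding to the Cooperative case. A hypothetical instantiation, shown as a comment sketch (the layout and fusion names are placeholders, not names from this commit):

// using Builder = cutlass::fmha::kernel::FmhaBuilder<
//     cutlass::half_t, float, float,                    // element and accumulator types
//     SomeTileShape,                                    // BlockQO, BlockKV, BlockHead (placeholder)
//     LayoutQ, LayoutK, LayoutV, SomeFusion,            // placeholders
//     cutlass::gemm::KernelTmaWarpSpecializedPingpong>; // resolves to the Cooperative builder
//                                                       // plus the kIsPersistent and
//                                                       // kLoadsQSeparately options
// using Kernel = typename Builder::Kernel;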

examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp (new file)
@ -0,0 +1,143 @@
/* BSD-3-Clause copyright header (2024 - 2025 NVIDIA CORPORATION & AFFILIATES), identical to the one reproduced in full at the top of fmha_device_bwd.hpp above. */

#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"

namespace cutlass::fmha::kernel {

using namespace cute;

template<class Element, class ElementAccumulator>
struct FmhaKernelBwdConvert {

  struct Arguments {
    tuple<int, int, int, int, int> problem_size;

    const ElementAccumulator* ptr_src_dQ;
    tuple<int, int, int, _1> stride_src_dQ;
    const ElementAccumulator* ptr_src_dK;
    tuple<int, int, int, _1> stride_src_dK;
    const ElementAccumulator* ptr_src_dV;
    tuple<int, int, int, _1> stride_src_dV;

    Element* ptr_dest_dQ;
    tuple<int, int, int, _1> stride_dest_dQ;
    Element* ptr_dest_dK;
    tuple<int, int, int, _1> stride_dest_dK;
    Element* ptr_dest_dV;
    tuple<int, int, int, _1> stride_dest_dV;
  };

  using Params = Arguments;

  using ClusterShape = Shape<_1, _1, _1>;
  static constexpr int SharedStorageSize = 0;

  static const int MinBlocksPerMultiprocessor = 1;
  static const int MaxThreadsPerBlock = 128;
  using ArchTag = cutlass::arch::Sm90;

  static const int kBlockSeq = 8;

  static size_t get_workspace_size(Arguments const& args) { return 0; }
  static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
    return cutlass::Status::kSuccess;
  }

  static const int kNumThreadsD = 16;
  static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
  static const int kElementsPerLoad = 4;

  static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;

  static bool can_implement(Arguments const& args) {
    return get<4>(args.problem_size) % kElementsPerLoad == 0;
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq));
    return grid;
  }

  static dim3 get_block_shape() {
    dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
    return block;
  }

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return args;
  }

  template<class StrideSrc, class StrideDest>
  CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
    auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y;
    auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y;

    for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
      int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
      if (idx_s >= count) continue;
      auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src);
      auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest);

      for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
        ElementAccumulator value_src[kElementsPerLoad];
        Element value_dest[kElementsPerLoad];

        using VecSrc = uint_bit_t<sizeof_bits_v<ElementAccumulator> * kElementsPerLoad>;
        using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
        *reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);

        for (int v = 0; v < kElementsPerLoad; v++) {
          value_dest[v] = value_src[v];
        }

        *reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
      }
    }
  }

  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
    if (params.ptr_src_dQ != nullptr) {
      copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size));
    }
    if (params.ptr_src_dK != nullptr) {
      copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size));
    }
    if (params.ptr_src_dV != nullptr) {
      copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size));
    }
  }
};

} // namespace cutlass::fmha::kernel
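
The inner loop above moves kElementsPerLoad values per thread through one wide integer register. The same vectorization trick in a standalone CUDA kernel, as a sketch assuming float-to-half conversion with 4-element-aligned, padded buffers:

#include <cstdint>
#include <cuda_fp16.h>

// One 128-bit load brings in 4 floats, one 64-bit store writes 4 halves,
// mirroring the VecSrc/VecDest reinterpret_cast pattern in the kernel above.
__global__ void convert_f32_to_f16(const float* src, __half* dst, int n) {
  int i = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (i + 4 > n) return;
  float v[4];
  __half h[4];
  *reinterpret_cast<uint4*>(v) = *reinterpret_cast<const uint4*>(&src[i]);  // 128-bit load
  for (int k = 0; k < 4; k++) { h[k] = __float2half(v[k]); }
  *reinterpret_cast<uint2*>(&dst[i]) = *reinterpret_cast<const uint2*>(h);  // 64-bit store
}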

examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp (new file)
@ -0,0 +1,134 @@
/* BSD-3-Clause copyright header (2024 - 2025 NVIDIA CORPORATION & AFFILIATES), identical to the one reproduced in full at the top of fmha_device_bwd.hpp above. */

#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"

namespace cutlass::fmha::kernel {

using namespace cute;

template<class Element, class ElementAccumulator>
struct FmhaKernelBwdSumOdO {

  struct Arguments {
    cute::tuple<int, int, int, int, int> problem_size;

    const Element* ptr_O;
    cute::tuple<int, int, int, cute::_1> stride_O;
    const Element* ptr_dO;
    cute::tuple<int, int, int, cute::_1> stride_dO;

    ElementAccumulator* ptr_sum_OdO;
    cute::tuple<int, int, _1> stride_sum_OdO;
  };

  using Params = Arguments;

  using ClusterShape = Shape<_1, _1, _1>;
  static constexpr int SharedStorageSize = 0;

  static const int MinBlocksPerMultiprocessor = 1;
  static const int MaxThreadsPerBlock = 128;
  using ArchTag = cutlass::arch::Sm90;

  static size_t get_workspace_size(Arguments const& args) { return 0; }
  static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
    return cutlass::Status::kSuccess;
  }

  static const int kBlockQ = 16;

  static const int kNumThreadsD = 8;
  static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
  static const int kElementsPerLoad = 2;

  static const int kIterationsQ = kBlockQ / kNumThreadsQ;

  static bool can_implement(Arguments const& args) {
    return get<4>(args.problem_size) % kElementsPerLoad == 0;
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size));
    return grid;
  }

  static dim3 get_block_shape() {
    dim3 block(kNumThreadsD, kNumThreadsQ, 1);
    return block;
  }

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return args;
  }

  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
    auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O);
    auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO);
    auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO);

    CUTLASS_PRAGMA_UNROLL
    for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
      int idx_q = idx_q_t + kBlockQ * blockIdx.x;
      if (idx_q >= get<2>(params.problem_size)) continue;
      ElementAccumulator acc = 0;
      auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O);
      auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO);
      auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO);

      for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
        Element value_O[kElementsPerLoad];
        Element value_dO[kElementsPerLoad];

        using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
        *reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
        *reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);

        for (int v = 0; v < kElementsPerLoad; v++) {
          acc += value_O[v] * value_dO[v];
        }
      }

      for (int i = 1; i < kNumThreadsD; i *= 2) {
        acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
      }

      if (threadIdx.x == 0) {
        *ptr_sum_OdO_bhq = acc;
      }
    }
  }
};

} // namespace cutlass::fmha::kernel
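
The __shfl_xor_sync loop above is a butterfly reduction over the kNumThreadsD = 8 lanes that share one query row. Isolated, the pattern looks like this (a sketch, not part of the commit):

// After log2(8) = 3 exchange steps, every lane in each aligned group of 8
// holds the sum of the group's values; the kernel above then lets the group's
// lane 0 (threadIdx.x == 0) write the result.
__device__ float butterfly_sum_8(float val) {
  for (int i = 1; i < 8; i *= 2) {
    val += __shfl_xor_sync(0xffffffffu, val, i, 8);
  }
  return val;
}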

examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp (new file)
@ -0,0 +1,222 @@
/* BSD-3-Clause copyright header (2024 - 2025 NVIDIA CORPORATION & AFFILIATES), identical to the one reproduced in full at the top of fmha_device_bwd.hpp above. */

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/arch.h"

#include "../kernel/fmha_tile_scheduler.hpp"
#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::kernel {

template<
  class CollectiveMainloop,
  class CollectiveEpilogue,
  class... Options
>
struct FmhaKernelTma {

  // Options
  static constexpr int kBlocksPerSM = find_option_t<Tag::kBlocksPerSM, Int<2>, Options...>::value;

  using Element = typename CollectiveMainloop::Element;
  using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;

  using TileScheduler = IndividualTileScheduler;

  using StagesQ = typename CollectiveMainloop::StagesQ;
  using Stages = typename CollectiveMainloop::Stages;

  using TileShape = typename CollectiveMainloop::TileShape;
  using ClusterShape = typename CollectiveMainloop::ClusterShape;

  using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
  using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;

  using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ;
  using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK;

  struct SharedStorage {
    union {
      typename CollectiveMainloop::SharedStorage mainloop;
      typename CollectiveEpilogue::TensorStorage epilogue;
    };

    using PipelineStorage = typename MainloopPipeline::SharedStorage;
    using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage;
    alignas(16) PipelineStorage pipeline_storage;
    alignas(16) PipelineStorageQ pipeline_storage_q;

    using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
    alignas(16) EpiLoadPipelineStorage epi_load;
  };

  static constexpr int SharedStorageSize = sizeof(SharedStorage);

  using ProblemShape = cute::tuple<int, int, int, int, int>;

  struct Arguments {
    ProblemShape problem_size;
    typename CollectiveMainloop::Arguments mainloop;
    typename CollectiveEpilogue::Arguments epilogue;
    KernelHardwareInfo hw_info;
  };

  struct Params {
    ProblemShape problem_size;
    typename CollectiveMainloop::Params mainloop;
    typename CollectiveEpilogue::Params epilogue;
    typename TileScheduler::Params tile_scheduler;
  };

  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
  using PipelineParamsQ = typename MainloopPipelineQ::Params;
  using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;

  static const int MinBlocksPerMultiprocessor = kBlocksPerSM;
  static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
  using ArchTag = cutlass::arch::Sm90;

  static size_t get_workspace_size(Arguments const& args) { return 0; }
  static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
    return cutlass::Status::kSuccess;
  }

  static bool can_implement(Arguments const& args) {
    return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
  }

  static dim3 get_grid_shape(Params const& params) {
    return TileScheduler::get_grid_shape(params.tile_scheduler);
  }

  static dim3 get_block_shape() {
    dim3 block(MaxThreadsPerBlock, 1, 1);
    return block;
  }

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return Params{
      args.problem_size,
      CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
      CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
      TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
    };
  }

  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
    TileScheduler tile_scheduler{params.tile_scheduler};

    // Shared memory.
    auto& storage = *reinterpret_cast<SharedStorage*>(smem);

    int thread_idx = int(threadIdx.x);

    uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();

    int warp_idx = cutlass::canonical_warp_idx_sync();
    int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
    int lane_predicate = cute::elect_one_sync();

    // Issue Tma Descriptor Prefetch from a single thread
    if ((warp_idx == 0) && lane_predicate) {
      CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
    }

    PipelineParamsQ pipeline_params_q;
    pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element);  // Q
    pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer;
    pipeline_params_q.is_leader = warp_group_thread_idx == 0;
    pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup;

    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);  // KV
    pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup;

    MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{});
    MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{});

    using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
    typename EpiLoadPipeline::Params epi_load_pipeline_params;
    epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer;
    epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
    epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
    epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
    epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
    EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);

    // State variables used for iterating the circular buffer
    // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
    // smem_pipe_write is used by the producer of SMEM data - i.e TMA
    PipelineState smem_pipe_read;
    PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();

    PipelineStateQ smem_pipe_read_q;
    PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();

    // We need this to guarantee that the Pipeline init is visible
    // To all producers and consumer blocks in the Cluster
    // and to finish smem init
    if constexpr (size(ClusterShape{}) > 1) {
      cute::cluster_arrive_relaxed();
      cute::cluster_wait();
    }
    else {
      __syncthreads();
    }

    auto blk_coord = tile_scheduler.get_block_coord();

    CollectiveMainloop collective_mainloop;
    auto result = collective_mainloop.compute(
      block_rank_in_cluster,
      blk_coord, params.mainloop, params.problem_size,
      pipeline, smem_pipe_read, smem_pipe_write,
      pipeline_q, smem_pipe_read_q, smem_pipe_write_q,
      storage.mainloop
    );

    CollectiveEpilogue epilogue;
    epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord,
      result, typename CollectiveMainloop::TiledMmaPV{},
      params.problem_size, params.epilogue,
      epi_load_pipeline, storage.epilogue);
  }
};

} // namespace cutlass::fmha::kernel
@ -0,0 +1,418 @@
/* BSD-3-Clause copyright header (2024 - 2025 NVIDIA CORPORATION & AFFILIATES), identical to the one reproduced in full at the top of fmha_device_bwd.hpp above. */

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/arch.h"

#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::kernel {

using namespace cute;

template<
  class CollectiveMainloop,
  class CollectiveEpilogue,
  class TileScheduler,
  class... Options
>
struct FmhaKernelTmaWarpSpecialized {

  // Options
  static constexpr bool kIsEpilogueLocked = find_option_t<Tag::kIsEpilogueLocked, false_type, Options...>::value;
  static constexpr bool kLoadsQSeparately = find_option_t<Tag::kLoadsQSeparately, false_type, Options...>::value;

  static const int NumLoadWarpGroups = 1;
  static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups;

  using TileShape = typename CollectiveMainloop::TileShape;
  using ClusterShape = typename CollectiveMainloop::ClusterShape;

  using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ;
  using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline;
  using MainloopPipelineReducer = cutlass::PipelineAsync<2>;

  static constexpr uint32_t StagesPerMathWarpGroup = 2;
  using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<
      StagesPerMathWarpGroup, NumMmaWarpGroups>;

  struct TensorStorageStruct {
    typename CollectiveMainloop::SharedStorage mainloop;
    typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
  };
  union TensorStorageUnion {
    typename CollectiveMainloop::SharedStorage mainloop;
    typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
  };
  using TensorStorage = std::conditional_t<CollectiveMainloop::kIsPersistent, TensorStorageStruct, TensorStorageUnion>;

  struct SharedStorage {
    TensorStorage tensors;

    using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage;
    using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage;
    using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage;

    alignas(16) PipelineStorageInner pipeline_storage_inner;
    alignas(16) PipelineStorageOuter pipeline_storage_outer;
    alignas(16) PipelineStorageReducer pipeline_storage_reducer;

    using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
    alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;

    alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier;

    using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
    alignas(16) EpiLoadPipelineStorage epi_load;
  };

  static constexpr int SharedStorageSize = sizeof(SharedStorage);

  using ProblemShape = cute::tuple<int, int, int, int, int>;

  struct Arguments {
    ProblemShape problem_size;
    typename CollectiveMainloop::Arguments mainloop;
    typename CollectiveEpilogue::Arguments epilogue;
    KernelHardwareInfo hw_info;
  };

  struct Params {
    ProblemShape problem_size;
    typename CollectiveMainloop::Params mainloop;
    typename CollectiveEpilogue::Params epilogue;
    typename TileScheduler::Params tile_scheduler;
  };

  using PipelineParamsInner = typename MainloopPipelineInner::Params;
  using PipelineStateInner = typename cutlass::PipelineState<MainloopPipelineInner::Stages>;
  using PipelineParamsOuter = typename MainloopPipelineOuter::Params;
  using PipelineStateOuter = typename cutlass::PipelineState<MainloopPipelineOuter::Stages>;
  using PipelineParamsReducer = typename MainloopPipelineReducer::Params;
  using PipelineStateReducer = typename cutlass::PipelineState<MainloopPipelineReducer::Stages>;

  static const int MinBlocksPerMultiprocessor = 1;
  static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup;
  using ArchTag = cutlass::arch::Sm90;

  static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8;
  static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup;
  static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8;
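  // Worked example (added for clarity, not in the original commit): with
  // NumMmaWarpGroups == 2, MaxThreadsPerBlock == 3 * 128 == 384, so the 64K
  // register file yields TotalRegisterSupply == (65536/384/1/8)*8 * 384/128 == 504
  // per-thread registers counted per warp group. The load warp group drops to
  // LoadRegisterRequirement == 24, leaving each MMA warp group
  // MmaRegisterRequirement == ((504 - 24)/2/8)*8 == 240 registers per thread;
  // 24*128 + 2*240*128 == 64512 <= 65536, so the budget fits.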
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return Params{
|
||||
args.problem_size,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
|
||||
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
|
||||
enum class WarpGroupRole {
|
||||
Producer = 0,
|
||||
Consumer0 = 1,
|
||||
Consumer1 = 2,
|
||||
Consumer2 = 3,
|
||||
Consumer3 = 4,
|
||||
};
|
||||
enum class ProducerWarpRole {
|
||||
LoadKV = 1,
|
||||
Reducer = 0,
|
||||
MaybeLoadQ = 2, // is kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it)
|
||||
MainloopEpilogue = 3,
|
||||
};
|
||||
|
||||
static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV;
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
// Shared memory.
|
||||
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
int lane_idx = cutlass::canonical_lane_idx();
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
auto warp_group_role = WarpGroupRole(warp_group_idx);
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
PipelineParamsOuter pipeline_params_outer;
|
||||
pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes;
|
||||
pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ);
|
||||
pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParamsInner pipeline_params_inner;
|
||||
pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes;
|
||||
pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV);
|
||||
pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParamsReducer pipeline_params_reducer;
|
||||
pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
||||
pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp;
|
||||
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) {
|
||||
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) {
|
||||
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) {
|
||||
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1 ||
|
||||
warp_group_role == WarpGroupRole::Consumer2 ||
|
||||
warp_group_role == WarpGroupRole::Consumer3
|
||||
) {
|
||||
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer;
|
||||
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer;
|
||||
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer;
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
|
||||
MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{});
|
||||
MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{});
|
||||
MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer);
|
||||
|
||||
// State variables used for iterating the circular buffer
|
||||
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
|
||||
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
|
||||
PipelineStateInner smem_pipe_read_inner;
|
||||
PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state<MainloopPipelineInner>();
|
||||
|
||||
PipelineStateOuter smem_pipe_read_outer;
|
||||
PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state<MainloopPipelineOuter>();
|
||||
|
||||
PipelineStateReducer smem_pipe_read_reducer;
|
||||
PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state<MainloopPipelineReducer>();
|
||||
|
||||
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_math_wg_order_barrier.group_id = consumer_warp_group_idx;
|
||||
params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||
MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier);
|
||||
|
||||
// Epilogue Load pipeline
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
|
||||
|
||||
if constexpr (kLoadsQSeparately) {
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer blocks in the Cluster
|
||||
// and to finish smem init
|
||||
    if constexpr (size(ClusterShape{}) > 1) {
      cute::cluster_arrive_relaxed();
      cute::cluster_wait();
    }
    else {
      __syncthreads();
    }

    CollectiveMainloop collective_mainloop;

    if (warp_group_role == WarpGroupRole::Producer) {
      cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
      if (producer_warp_role == ProducerWarpRole::LoadKV) {
        bool do_barrier = kLoadsQSeparately;

        CUTLASS_PRAGMA_NO_UNROLL
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          collective_mainloop.template load_kv_maybe_q<!kLoadsQSeparately>(
            block_rank_in_cluster,
            blk_coord, params.mainloop, params.problem_size,
            pipeline_inner, smem_pipe_write_inner,
            pipeline_outer, smem_pipe_write_outer,
            storage.tensors.mainloop,
            storage.load_warp_barrier, do_barrier
          );
          do_barrier = false;
        }
      }
      else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) {
        bool do_barrier = true;

        CUTLASS_PRAGMA_NO_UNROLL
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          collective_mainloop.load_maybe_q(
            blk_coord, params.mainloop, params.problem_size,
            pipeline_outer, smem_pipe_write_outer,
            storage.tensors.mainloop,
            storage.load_warp_barrier, do_barrier
          );
          do_barrier = false;
        }
      } else if (producer_warp_role == ProducerWarpRole::Reducer) {
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          collective_mainloop.reduce(
            blk_coord, params.mainloop, params.problem_size,
            pipeline_reducer, smem_pipe_read_reducer,
            storage.tensors.mainloop
          );
        }
      }
    }
    else if (
      warp_group_role == WarpGroupRole::Consumer0 ||
      warp_group_role == WarpGroupRole::Consumer1 ||
      warp_group_role == WarpGroupRole::Consumer2 ||
      warp_group_role == WarpGroupRole::Consumer3
    ) {
      cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();
        auto wg_coord = blk_coord;

        constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads;

        if (warp_group_role == WarpGroupRole::Consumer0) {
          smem_pipe_read_outer.advance(0 * kOuterLoads);
        }
        else if (warp_group_role == WarpGroupRole::Consumer1) {
          smem_pipe_read_outer.advance(1 * kOuterLoads);
        }
        else if (warp_group_role == WarpGroupRole::Consumer2) {
          smem_pipe_read_outer.advance(2 * kOuterLoads);
        }
        else if (warp_group_role == WarpGroupRole::Consumer3) {
          smem_pipe_read_outer.advance(3 * kOuterLoads);
        }

        constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 0 : 1;
        auto& wg_block = get<wg_dim>(wg_coord);
        if (warp_group_role == WarpGroupRole::Consumer0) {
          wg_block = NumMmaWarpGroups * wg_block + 0;
        }
        else if (warp_group_role == WarpGroupRole::Consumer1) {
          wg_block = NumMmaWarpGroups * wg_block + 1;
        }
        else if (warp_group_role == WarpGroupRole::Consumer2) {
          wg_block = NumMmaWarpGroups * wg_block + 2;
        }
        else if (warp_group_role == WarpGroupRole::Consumer3) {
          wg_block = NumMmaWarpGroups * wg_block + 3;
        }

        auto result = collective_mainloop.compute(
          blk_coord, wg_coord,
          params.mainloop, params.problem_size,
          pipeline_inner, smem_pipe_read_inner,
          pipeline_outer, smem_pipe_read_outer,
          pipeline_reducer, smem_pipe_write_reducer,
          storage.tensors.mainloop,
          math_wg_order_barrier
        );

        if (warp_group_role == WarpGroupRole::Consumer0) {
          smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0));
        }
        if constexpr (NumMmaWarpGroups >= 2) {
          if (warp_group_role == WarpGroupRole::Consumer1) {
            smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1));
          }
        }
        if constexpr (NumMmaWarpGroups >= 3) {
          if (warp_group_role == WarpGroupRole::Consumer2) {
            smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2));
          }
        }
        if constexpr (NumMmaWarpGroups >= 4) {
          if (warp_group_role == WarpGroupRole::Consumer3) {
            smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3));
          }
        }

        if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.wait(); }

        CollectiveEpilogue epilogue;
        epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord,
                 result, typename CollectiveMainloop::TiledMmaPV{},
                 params.problem_size, params.epilogue,
                 epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]);

        if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.arrive(); }
      }
    }
  }
};

} // namespace cutlass::fmha::kernel
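The consumer branch above first advances the outer-stage read pipeline by consumer_index * kOuterLoads and then remaps the block coordinate so that each MMA warp group owns a distinct sub-tile. A minimal sketch of that remap arithmetic (the group count here is an assumed example value, not taken from the diff):

// Illustrative only: grid-level tile b expands to warp-group tile
// NumMmaWarpGroups * b + consumer_idx, mirroring the wg_block remap above,
// so adjacent consumer warp groups cover adjacent tiles of one CTA's work.
inline int wg_tile(int b, int consumer_idx, int num_mma_warp_groups /* e.g. 2 */) {
  return num_mma_warp_groups * b + consumer_idx;
}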
examples/88_hopper_fmha/kernel/fmha_options.hpp (new file)
@ -0,0 +1,83 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"

namespace cutlass::fmha::kernel {

template<auto kTag, typename Default, typename... Options>
struct find_option;

template<auto kTag, typename Default>
struct find_option<kTag, Default> {
  using option_value = Default;
};

template<auto kTag, typename Default, typename Option, typename... Options>
struct find_option<kTag, Default, Option, Options...> :
  std::conditional_t<
    Option::tag == kTag,
    Option,
    find_option<kTag, Default, Options...>
  >
{};

template<auto kTag, typename Default, typename... Options>
using find_option_t = typename find_option<kTag, Default, Options...>::option_value;

enum class Tag {
  kIsPersistent,
  kNumMmaWarpGroups,
  kLoadsQSeparately,

  kIsMainloopLocked,
  kIsEpilogueLocked,

  kStagesQ,
  kStagesKV,

  kEpilogueKind,

  kBlocksPerSM,
  kClusterM,

  kAccQK
};

template<auto kTag, class Value>
struct Option {
  static constexpr auto tag = kTag;
  using option_value = Value;
};

} // namespace cutlass::fmha::kernel
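A minimal sketch of how this option machinery resolves (the static_asserts and the two stand-in value types below are illustrative, not part of the diff):

#include <type_traits>

namespace k = cutlass::fmha::kernel;

struct StagesDefault { static constexpr int value = 2; };  // stand-in value types
struct StagesFour    { static constexpr int value = 4; };

// Without a matching tag in the option list, find_option_t yields the default...
static_assert(std::is_same_v<
    k::find_option_t<k::Tag::kStagesQ, StagesDefault>,
    StagesDefault>);
// ...and a matching Option substitutes its value type.
static_assert(std::is_same_v<
    k::find_option_t<k::Tag::kStagesQ, StagesDefault,
                     k::Option<k::Tag::kStagesQ, StagesFour>>,
    StagesFour>);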
examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp (new file)
@ -0,0 +1,204 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD-3-Clause license text identical to the header above)
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"

namespace cutlass::fmha::kernel {

////////////////////////////////////////////////////////////////////////////////

struct IndividualTileScheduler {

  struct Params {
    dim3 grid;
  };

  bool valid_ = true;

  CUTLASS_DEVICE
  IndividualTileScheduler(Params const&) {}

  template<class ProblemSize, class ClusterShape, class TileShape>
  static Params to_underlying_arguments(
      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, TileShape const& tile_shape)
  {
    using namespace cute;
    dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size));
    return Params{ grid };
  }

  static dim3 get_grid_shape(Params const& params) {
    return params.grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return valid_;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
  }

  CUTLASS_DEVICE
  IndividualTileScheduler& operator++() {
    valid_ = false;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

struct PersistentTileScheduler {

  struct Params {
    int num_blocks;
    FastDivmod divmod_m_block;
    FastDivmod divmod_b;
    FastDivmod divmod_h;

    KernelHardwareInfo hw_info;
  };

  int block_idx = 0;
  Params params;

  CUTLASS_DEVICE
  PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}

  template<class ProblemSize, class ClusterShape, class TileShape>
  static Params to_underlying_arguments(
      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, TileShape const& tile_shape)
  {
    using namespace cute;
    // Get SM count if needed, otherwise use user supplied SM count
    int sm_count = hw_info.sm_count;
    if (sm_count <= 0) {
      CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
          "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    }

    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
    hw_info.sm_count = sm_count;

    int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
    int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size);

    return Params {
      num_blocks,
      { num_m_blocks }, { size<0>(problem_size) }, { size<1>(problem_size) },
      hw_info
    };
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
    return grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return block_idx < params.num_blocks;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    int block_decode = block_idx;
    int m_block, bidb, bidh;
    params.divmod_m_block(block_decode, m_block, block_decode);
    params.divmod_b(block_decode, bidb, block_decode);
    params.divmod_h(block_decode, bidh, block_decode);
    return make_coord(m_block, _0{}, make_coord(bidb, bidh));
  }

  CUTLASS_DEVICE
  PersistentTileScheduler& operator++() {
    block_idx += gridDim.x;
    return *this;
  }
};
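get_block_coord() above unpacks the linear persistent block index with three FastDivmods; in plain arithmetic this is equivalent to the following sketch (num_m_blocks, batch, and heads stand in for the divisors stored in Params):

int idx     = block_idx;
int m_block = idx % num_m_blocks;  idx /= num_m_blocks;
int bidb    = idx % batch;         idx /= batch;
int bidh    = idx % heads;
// -> make_coord(m_block, _0{}, make_coord(bidb, bidh))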

////////////////////////////////////////////////////////////////////////////////

template<typename Base>
struct TileSchedulerBwdAdapter {

  using Params = typename Base::Params;

  Base base_;

  CUTLASS_DEVICE
  TileSchedulerBwdAdapter(Params const& params) : base_(params) {}

  template<class ProblemSize, class ClusterShape, class TileShape>
  static Params to_underlying_arguments(
      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, TileShape const& tile_shape)
  {
    using namespace cute;
    return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape));
  }

  static dim3 get_grid_shape(Params const& params) {
    return Base::get_grid_shape(params);
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return base_.is_valid();
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    return select<1,0,2>(base_.get_block_coord());
  }

  CUTLASS_DEVICE
  TileSchedulerBwdAdapter& operator++() {
    ++base_;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::kernel
examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp (new file)
@ -0,0 +1,357 @@
/***************************************************************************************************
 * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD-3-Clause license text identical to the headers above)
 **************************************************************************************************/

#pragma once

#include "cute/tensor.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ, class TensorK, class TensorV,
  class TensorO, class TensorLSE, class TensorDO,
  class TensorDQ, /* class TensorDK, class TensorDV, */
  class Fusion
>
void __global__ fmha_bwd_reference_dQ_kernel(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE, TensorDO mDO,
    TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
    Fusion fusion
) {
  using namespace cute;

  using Element = typename TensorO::value_type;
  using ElementAccumulator = typename TensorLSE::value_type;

  extern __shared__ char mS_mem[];
  Element* mS = reinterpret_cast<Element*>(mS_mem);

  Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));

  for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
    for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {

      for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
        ElementAccumulator acc_qk = 0;
        ElementAccumulator acc_dov = 0;
        ElementAccumulator acc_doo = 0;
        for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
          acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
          acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
          acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
        }

        auto id = make_identity_tensor(make_shape(1, 1));
        auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
        frag(0) = acc_qk;
        fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
        acc_qk = frag(0);

        mS[idx_K] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
      }

      __syncthreads();

      for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
        ElementAccumulator acc = 0;
        for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
          acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
        }
        mDQ(idx_Q, idx_D, idx_L) = acc;
      }
    }
  }
}
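Ignoring the optional fusion hook applied before the exponent, the quantity staged in mS[idx_K] and the final accumulation above match the standard attention backward identities. With scale s = 1/sqrt(d):

\[
P_{qk} = \exp\!\left(s\, Q_q \cdot K_k - \mathrm{LSE}_q\right), \qquad
dS_{qk} = s\, P_{qk} \left(dO_q \cdot V_k - dO_q \cdot O_q\right), \qquad
dQ_q = \sum_k dS_{qk}\, K_k .
\]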

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ, class TensorK, class TensorV,
  class TensorO, class TensorLSE, class TensorDO,
  /* class TensorDQ, */ class TensorDK, /* class TensorDV, */
  class Fusion
>
void __global__ fmha_bwd_reference_dK_kernel(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE, TensorDO mDO,
    /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
    Fusion fusion
) {
  using namespace cute;

  using Element = typename TensorO::value_type;
  using ElementAccumulator = typename TensorLSE::value_type;

  extern __shared__ char mS_mem[];
  Element* mS = reinterpret_cast<Element*>(mS_mem);

  Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));

  for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
    for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {

      for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
        ElementAccumulator acc_qk = 0;
        ElementAccumulator acc_dov = 0;
        ElementAccumulator acc_doo = 0;
        for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
          acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
          acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
          acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
        }

        auto id = make_identity_tensor(make_shape(1, 1));
        auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
        frag(0) = acc_qk;
        fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
        acc_qk = frag(0);

        mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
      }

      __syncthreads();

      for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
        ElementAccumulator acc = 0;
        for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
          acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
        }
        mDK(idx_K, idx_D, idx_L) = acc;
      }
    }
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ, class TensorK, class TensorV,
  class TensorO, class TensorLSE, class TensorDO,
  /* class TensorDQ, class TensorDK, */ class TensorDV,
  class Fusion
>
void __global__ fmha_bwd_reference_dV_kernel(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE, TensorDO mDO,
    /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
    Fusion fusion
) {
  using namespace cute;

  using Element = typename TensorO::value_type;
  using ElementAccumulator = typename TensorLSE::value_type;

  extern __shared__ char mS_mem[];
  Element* mS = reinterpret_cast<Element*>(mS_mem);

  Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));

  for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
    for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {

      for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
        ElementAccumulator acc_qk = 0;
        for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
          acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
        }

        auto id = make_identity_tensor(make_shape(1, 1));
        auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
        frag(0) = acc_qk;
        fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
        acc_qk = frag(0);

        mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)));
      }

      __syncthreads();

      for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
        ElementAccumulator acc = 0;
        for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
          acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L);
        }
        mDV(idx_K, idx_D, idx_L) = acc;
      }
    }
  }
}
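The dK and dV kernels above stage the same quantities with q and k exchanging roles:

\[
dK_k = \sum_q dS_{qk}\, Q_q, \qquad dV_k = \sum_q P_{qk}\, dO_q .
\]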

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ, class TensorK, class TensorV,
  class TensorO, class TensorLSE, class TensorDO,
  /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/
  class Fusion
>
void fmha_bwd_reference_dQ(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE, TensorDO mDO,
    /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/
    Fusion fusion
) {
  using namespace cute;

  dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
  dim3 block(256);
  int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);

  if (shared_mem >= (48 << 10)) {
    CUTLASS_TRACE_HOST("  Setting smem size to " << shared_mem);
    auto result = cudaFuncSetAttribute(
        fmha_bwd_reference_dQ_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDQ, Fusion>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        shared_mem);
    if (cudaSuccess != result) {
      result = cudaGetLastError();  // to clear the error bit
      CUTLASS_TRACE_HOST(
          "  cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
      return;
    }
  }

  fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ, class TensorK, class TensorV,
  class TensorO, class TensorLSE, class TensorDO,
  /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/
  class Fusion
>
void fmha_bwd_reference_dK(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE, TensorDO mDO,
    /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/
    Fusion fusion
) {
  using namespace cute;

  dim3 grid(size<0>(mDK), size<2>(mDK), 1);
  dim3 block(256);
  int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);

  if (shared_mem >= (48 << 10)) {
    CUTLASS_TRACE_HOST("  Setting smem size to " << shared_mem);
    auto result = cudaFuncSetAttribute(
        fmha_bwd_reference_dK_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDK, Fusion>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        shared_mem);
    if (cudaSuccess != result) {
      result = cudaGetLastError();  // to clear the error bit
      CUTLASS_TRACE_HOST(
          "  cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
      return;
    }
  }

  fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ, class TensorK, class TensorV,
  class TensorO, class TensorLSE, class TensorDO,
  /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/
  class Fusion
>
void fmha_bwd_reference_dV(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE, TensorDO mDO,
    /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/
    Fusion fusion
) {
  using namespace cute;

  dim3 grid(size<0>(mDV), size<2>(mDV), 1);
  dim3 block(256);
  int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);

  if (shared_mem >= (48 << 10)) {
    CUTLASS_TRACE_HOST("  Setting smem size to " << shared_mem);
    auto result = cudaFuncSetAttribute(
        fmha_bwd_reference_dV_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDV, Fusion>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        shared_mem);
    if (cudaSuccess != result) {
      result = cudaGetLastError();  // to clear the error bit
      CUTLASS_TRACE_HOST(
          "  cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
      return;
    }
  }

  fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ, class TensorK, class TensorV,
  class TensorO, class TensorLSE, class TensorDO,
  class TensorDQ, class TensorDK, class TensorDV,
  class Fusion
>
void fmha_bwd_reference(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE, TensorDO mDO,
    TensorDQ mDQ, TensorDK mDK, TensorDV mDV,
    Fusion fusion
) {
  fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
  fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
  fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
examples/88_hopper_fmha/reference/fmha_reference.hpp (new file)
@ -0,0 +1,156 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD-3-Clause license text identical to the headers above)
 **************************************************************************************************/
#pragma once

#include <algorithm>
#include <limits>

#include "cute/tensor.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ,
  class TensorK,
  class TensorV,
  class TensorO,
  class TensorLSE,
  class Fusion
>
void __global__ fmha_reference_kernel(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE,
    Fusion fusion
) {
  using namespace cute;

  using Element = typename TensorO::value_type;
  using ElementAccumulator = typename TensorLSE::value_type;

  extern __shared__ char mS_mem[];
  Element* mS = reinterpret_cast<Element*>(mS_mem);

  ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));

  auto id = make_identity_tensor(make_shape(1, 1));
  for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) {
    for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) {
      for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
        ElementAccumulator acc = 0;
        for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) {
          acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L);
        }
        auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
        frag(0) = acc;
        fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
        mS[idx_K] = static_cast<Element>(frag(0) * softmax_scale);
      }

      __syncthreads();

      ElementAccumulator maxS = -std::numeric_limits<ElementAccumulator>::infinity();
      for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
        maxS = std::max<ElementAccumulator>(maxS, mS[idx_K]);
      }
      if (maxS == -std::numeric_limits<ElementAccumulator>::infinity()) maxS = 0;

      __syncthreads();

      for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
        mS[idx_K] = static_cast<Element>(exp(mS[idx_K] - maxS));
      }

      __syncthreads();

      ElementAccumulator sum = 0;
      for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
        sum += mS[idx_K];
      }

      Element scale = static_cast<Element>(1.0 / sum);

      for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) {
        ElementAccumulator acc = 0;
        for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
          acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale;
        }
        mO(idx_Q, idx_D, idx_L) = static_cast<Element>(acc);
      }

      if (threadIdx.x == 0) {
        mLSE(idx_Q, idx_L) = log(sum) + maxS;
      }

    }
  }
}
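Up to the optional fusion applied before the softmax, the kernel above computes, per query row q (with s = 1/sqrt(d) and m_q the row maximum used for numerical stability):

\[
S_{qk} = s\, Q_q \cdot K_k, \qquad
O_q = \frac{\sum_k e^{S_{qk} - m_q}\, V_k}{\sum_k e^{S_{qk} - m_q}}, \qquad
\mathrm{LSE}_q = m_q + \log \sum_k e^{S_{qk} - m_q} .
\]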

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  class ProblemShape,
  class TensorQ,
  class TensorK,
  class TensorV,
  class TensorO,
  class TensorLSE,
  class Fusion
>
void fmha_reference(
    ProblemShape problem_shape,
    TensorQ mQ, TensorK mK, TensorV mV,
    TensorO mO, TensorLSE mLSE,
    Fusion fusion
) {
  using namespace cute;

  dim3 grid(size<0>(mO), size<2>(mO), 1);
  dim3 block(256);
  int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);

  if (shared_mem >= (48 << 10)) {
    CUTLASS_TRACE_HOST("  Setting smem size to " << shared_mem);
    auto result = cudaFuncSetAttribute(
        fmha_reference_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, Fusion>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        shared_mem);
    if (cudaSuccess != result) {
      result = cudaGetLastError();  // to clear the error bit
      CUTLASS_TRACE_HOST(
          "  cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
      return;
    }
  }

  fmha_reference_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
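A hedged sketch of wiring tensors into this launcher (pointer and extent names are hypothetical; only the call shape and the (seq, head_dim, L) indexing convention come from the code above, with L flattening batch and heads):

using namespace cute;
Tensor mQ   = make_tensor(make_gmem_ptr(d_Q),   make_layout(make_shape(seq_q, head_dim, L)));
Tensor mK   = make_tensor(make_gmem_ptr(d_K),   make_layout(make_shape(seq_k, head_dim, L)));
Tensor mV   = make_tensor(make_gmem_ptr(d_V),   make_layout(make_shape(seq_k, head_dim, L)));
Tensor mO   = make_tensor(make_gmem_ptr(d_O),   make_layout(make_shape(seq_q, head_dim, L)));
Tensor mLSE = make_tensor(make_gmem_ptr(d_LSE), make_layout(make_shape(seq_q, L)));
// `fusion` stands in for whatever Fusion functor the example defines; it only
// needs the before_softmax(frag, coord, problem_shape) member used above.
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, fusion);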
examples/88_hopper_fmha/reference/reference_abs_error.hpp (new file)
@ -0,0 +1,129 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD-3-Clause license text identical to the headers above)
 **************************************************************************************************/

#pragma once

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include "cutlass/util/device_memory.h"

template<typename Element>
__global__ void reference_abs_diff_kernel(
    Element* data, Element* data_ref, size_t count,
    double* max_diff, double* sum_diff,
    bool print_diff
) {
  double thread_max_diff = 0;
  double thread_sum_diff = 0;

  __shared__ double block_max_diff;
  __shared__ double block_sum_diff;

  for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
    double diff = fabs(data[i] - data_ref[i]);
    // (diff != diff) detects NaN
    if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
    thread_max_diff = fmax(diff, thread_max_diff);
    thread_sum_diff += diff;
  }

  // Serialized block-level reduction: each thread in turn folds its partial
  // results into the shared accumulators.
  for (int i = 0; i < blockDim.x; i++) {
    if (i == threadIdx.x) {
      if (i == 0) {
        block_max_diff = thread_max_diff;
        block_sum_diff = thread_sum_diff;
      } else {
        block_max_diff = fmax(block_max_diff, thread_max_diff);
        block_sum_diff += thread_sum_diff;
      }
    }
    __syncthreads();
  }

  if (threadIdx.x == 0) {
    atomicAdd(sum_diff, block_sum_diff);

    // Emulate a double-precision atomicMax with a compare-and-swap loop over
    // the value's bit pattern.
    for (;;) {
      unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
      double prev_diff = reinterpret_cast<double const&>(prev);
      double new_max_diff = fmax(block_max_diff, prev_diff);
      unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
      if (found == prev) break;
    }
  }
}

template<typename Element>
void reference_abs_diff(
    cutlass::DeviceAllocation<Element> const& data,
    cutlass::DeviceAllocation<Element> const& data_ref,
    double& max_diff, double& mean_diff
) {
  static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;

  cutlass::DeviceAllocation<double> result;
  result.reset(2);
  assert(data.size() == data_ref.size());

  cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
  if (err != cudaSuccess) {
    std::cerr << "Memset failed. Last CUDA error: "
              << cudaGetErrorString(err) << std::endl;
    max_diff = mean_diff = 1e20;
    return;
  }

  dim3 block(256, 1, 1);
  dim3 grid(1024, 1, 1);
  reference_abs_diff_kernel<<<grid, block>>>(
      data.get(), data_ref.get(), data.size(),
      result.get(), result.get() + 1, kPrintDiff);

  err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    std::cerr << "Difference kernel failed. Last CUDA error: "
              << cudaGetErrorString(err) << std::endl;
    max_diff = mean_diff = 1e20;
    return;
  }

  double result_host[2];
  err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
  if (err != cudaSuccess) {
    std::cerr << "Copy failed. Last CUDA error: "
              << cudaGetErrorString(err) << std::endl;
    max_diff = mean_diff = 1e20;
    return;
  }

  max_diff = result_host[0];
  mean_diff = result_host[1] / static_cast<double>(data.size());
}
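A hedged usage sketch (the two allocations are hypothetical placeholders for the tested output and the reference output):

cutlass::DeviceAllocation<float> out, out_ref;  // filled elsewhere
double max_diff = 0.0, mean_diff = 0.0;
reference_abs_diff(out, out_ref, max_diff, mean_diff);
std::cout << "max |diff| = " << max_diff << ", mean |diff| = " << mean_diff << std::endl;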
@ -163,6 +163,7 @@ foreach(EXAMPLE
  82_blackwell_distributed_gemm
  83_blackwell_sparse_gemm
  84_blackwell_narrow_precision_sparse_gemm
  88_hopper_fmha
)

add_subdirectory(${EXAMPLE})

@ -55,3 +55,7 @@ cutlass_example_add_executable(
  tiled_copy.cu
)

cutlass_example_add_executable(
  cute_tutorial_tiled_copy_if
  tiled_copy_if.cu
)

examples/cute/tutorial/tiled_copy_if.cu (new file)
@ -0,0 +1,297 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD-3-Clause license text identical to the headers above)
 **************************************************************************************************/

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

// This example extends `tiled_copy` using predicate tensors to guard memory accesses performed
// by `cute::copy_if()`. This enables tensors to have shapes that are not integer multiples of
// block sizes.
//
// This is accomplished by instantiating a tensor of coordinates corresponding to the tensor
// elements to be accessed and then computing a predicate tensor which masks accesses. The example
// demonstrates how constructing an identity tensor containing coordinates and a predicate tensor
// containing mask bits can be implemented with the same CuTe operations used to tile the tensors
// in global memory.
//
// This example implements two variants:
//   - copy_if_kernel() uses `cute::local_partition()` to construct each thread's slice
//   - copy_if_kernel_vectorized() uses `make_tiled_copy()` to implement vectorized memory accesses.
//
// The tensor shapes and strides must be divisible by the shape of the vector access.
//

/// Simple copy kernel.
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class BlockShape, class ThreadLayout>
__global__ void copy_if_kernel(TensorS S, TensorD D, BlockShape block_shape, ThreadLayout)
{
  using namespace cute;

  // Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
  auto shape_S = shape(S);
  Tensor C = make_identity_tensor(shape_S);
  // Construct a predicate tensor which compares the coordinates with the original shape
  Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });

  // Tile the input tensor into blocks
  auto block_coord = make_coord(blockIdx.x, blockIdx.y);
  Tensor tile_S = local_tile(S, block_shape, block_coord);  // (BlockShape_M, BlockShape_N)
  Tensor tile_D = local_tile(D, block_shape, block_coord);  // (BlockShape_M, BlockShape_N)
  Tensor tile_P = local_tile(P, block_shape, block_coord);  // (BlockShape_M, BlockShape_N)

  // Construct a partitioning of the tile among threads with the given thread arrangement.

  // Concept:                    Tensor  ThrLayout       ThrIndex
  Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
  Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
  Tensor thr_tile_P = local_partition(tile_P, ThreadLayout{}, threadIdx.x);

  // Copy from GMEM to GMEM using `thr_tile_P` to guard accesses.
  copy_if(thr_tile_P, thr_tile_S, thr_tile_D);
}

/// Vectorized copy kernel.
///
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class BlockShape, class Tiled_Copy>
__global__ void copy_if_kernel_vectorized(TensorS S, TensorD D, BlockShape block_shape, Tiled_Copy tiled_copy)
{
  using namespace cute;

  // Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
  auto shape_S = shape(S);
  Tensor C = make_identity_tensor(shape_S);
  // Construct a predicate tensor which compares the coordinates with the original shape
  Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });

  // Tile the input tensor into blocks
  auto block_coord = make_coord(blockIdx.x, blockIdx.y);
  Tensor tile_S = local_tile(S, block_shape, block_coord);  // (BlockShape_M, BlockShape_N)
  Tensor tile_D = local_tile(D, block_shape, block_coord);  // (BlockShape_M, BlockShape_N)
  Tensor tile_P = local_tile(P, block_shape, block_coord);  // (BlockShape_M, BlockShape_N)

  //
  // Construct a Tensor corresponding to each thread's slice.
  //
  ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
  Tensor thr_tile_S = thr_copy.partition_S(tile_S);  // (CPY, CPY_M, CPY_N)
  Tensor thr_tile_D = thr_copy.partition_D(tile_D);  // (CPY, CPY_M, CPY_N)
  Tensor thr_tile_P = thr_copy.partition_S(tile_P);  // (CPY, CPY_M, CPY_N)

#if 0
  // Copy from GMEM to GMEM
  copy_if(tiled_copy, thr_tile_P, thr_tile_S, thr_tile_D);
#else
  // make_fragment_like() constructs a tensor in RMEM with the same shape as thr_tile_S.
  Tensor frag = make_fragment_like(thr_tile_S);

  // Copy from GMEM to RMEM and from RMEM to GMEM
  copy_if(tiled_copy, thr_tile_P, thr_tile_S, frag);
  copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D);
#endif
}

/// Main function
int main(int argc, char** argv)
{
  //
  // Given a 2D shape, perform an efficient copy
  //

  using namespace cute;
  using Element = float;

  // Define a tensor shape with dynamic extents (m, n)
  auto tensor_shape = make_shape(528, 300);

  thrust::host_vector<Element> h_S(size(tensor_shape));
  thrust::host_vector<Element> h_D(size(tensor_shape));

  //
  // Initialize
  //

  for (size_t i = 0; i < h_S.size(); ++i) {
    h_S[i] = static_cast<Element>(i);
    h_D[i] = Element{};
  }

  thrust::device_vector<Element> d_S = h_S;
  thrust::device_vector<Element> d_D = h_D;
  thrust::device_vector<Element> d_Zero = h_D;

  //
  // Make tensors
  //

  Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
  Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));

  //
  // Partition
  //

  // Define a statically sized block (M, N).
  //
  // Note, by convention, capital letters are used to represent static modes.
  auto block_shape = make_shape(Int<128>{}, Int<64>{});

  // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
  // shape, and modes (m', n') correspond to the number of tiles.
  //
  // These will be used to determine the CUDA kernel grid dimensions.
  Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape);  // ((M, N), m', n')

  // Describes the layout of threads which is then replicated to tile 'block_shape'.
  Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{}));  // (ThrM, ThrN)

  //
  // Determine grid and block dimensions
  //

  dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D));  // Grid shape corresponds to modes m' and n'
  dim3 blockDim(size(thr_layout));

  //
  // Launch the kernel
  //

  // copy_if()
  copy_if_kernel<<< gridDim, blockDim >>>(
    tensor_S,
    tensor_D,
    block_shape,
    thr_layout);

  cudaError result = cudaDeviceSynchronize();

  if (result != cudaSuccess) {
    std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
    return -1;
  }

  h_D = d_D;

  //
  // Verification
  //

  auto verify = [](thrust::host_vector<Element> const &S, thrust::host_vector<Element> const &D) {

    int32_t errors = 0;
    int32_t const kErrorLimit = 10;

    if (S.size() != D.size()) {
      return 1;
    }

    for (size_t i = 0; i < D.size(); ++i) {
      if (S[i] != D[i]) {
        std::cerr << "Error. S[" << i << "]: " << S[i] << ",   D[" << i << "]: " << D[i] << std::endl;

        if (++errors >= kErrorLimit) {
          std::cerr << "Aborting after " << kErrorLimit << " errors." << std::endl;
          return errors;
        }
      }
    }

    return errors;
  };

  if (verify(h_D, h_S)) {
    return -1;
  } else {
    std::cout << "Success." << std::endl;
  }

  thrust::copy(d_Zero.begin(), d_Zero.end(), d_D.begin());

  // Construct a TiledCopy with a specific access pattern.
  // This version uses a
  //   (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc),
  //   (2) Layout-of-Values that each thread will access.

  // Value arrangement per thread
  Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));  // (4,1) -> val_idx

  // Define `AccessType` which controls the size of the actual memory access instruction.
  using CopyOp = UniversalCopy<uint_byte_t<sizeof(Element) * size(val_layout)>>;  // A very specific access width copy instruction
  //using CopyOp = UniversalCopy<cutlass::AlignedArray<Element, size(val_layout)>>;  // A more generic type that supports many copy strategies
  //using CopyOp = AutoVectorizingCopy;  // An adaptable-width instruction that assumes maximal alignment of inputs

  // A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element.
  using Atom = Copy_Atom<CopyOp, Element>;

  // Construct tiled copy, a tiling of copy atoms.
  //
  // Note, this assumes the vector and thread layouts are aligned with contiguous data
  // in GMEM. Alternative thread layouts are possible but may result in uncoalesced
  // reads. Alternative value layouts are also possible, though incompatible layouts
  // will result in compile time errors.
  TiledCopy tiled_copy = make_tiled_copy(Atom{},       // Access strategy
                                         thr_layout,   // thread layout (e.g. 32x4 Col-Major)
                                         val_layout);  // value layout (e.g. 4x1)

  // copy_if() with vectorization
  copy_if_kernel_vectorized<<< gridDim, blockDim >>>(
    tensor_S,
    tensor_D,
    block_shape,
    tiled_copy);

  result = cudaDeviceSynchronize();

  if (result != cudaSuccess) {
    std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
    return -1;
  }

  h_D = d_D;

  if (verify(h_D, h_S)) {
    return -1;
  } else {
    std::cout << "Success." << std::endl;
  }
  return 0;
}

examples/python/CuTeDSL/ampere/smem_allocator.py (new file)
@ -0,0 +1,200 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text identical to the headers above)

import cutlass.cute as cute
import cutlass
import torch
import numpy as np
from cutlass.cute.runtime import from_dlpack

"""
A Shared Memory Allocator Example on the NVIDIA Ampere architecture using CuTe DSL.

This example demonstrates how to allocate and manage shared memory in JIT kernels
using the SmemAllocator in CuTe DSL. It shows several ways to allocate different
data structures in shared memory:

1. Struct allocation with natural and strict alignment
2. Raw memory block allocation with custom alignment
3. Array allocation with automatic alignment
4. Tensor allocation with layout specification

The example includes:
- A shared storage struct with mixed alignment requirements
- Memory allocation patterns for different data types
- Tensor operations on the allocated memory

To run this example:

.. code-block:: bash

    python examples/ampere/smem_allocator.py

The example will allocate shared memory, perform tensor operations, and verify the results.
"""


@cute.struct
class complex:
    real: cutlass.Float32
    imag: cutlass.Float32


# SharedStorage size is 512 bytes, alignment is 128 bytes
@cute.struct
class SharedStorage:
    # struct elements with natural alignment
    a: cute.struct.MemRange[cutlass.Float32, 32]  # array
    b: cutlass.Int64                              # scalar
    c: complex                                    # nested struct
    # struct elements with strict alignment
    x: cute.struct.Align[
        cute.struct.MemRange[cutlass.Float32, 32],
        128,
    ]
    y: cute.struct.Align[cutlass.Int32, 8]
    z: cute.struct.Align[complex, 16]
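    # Rough offset map for SharedStorage (a sketch, assuming C-style packing
    # with the alignments declared above):
    #   a: bytes   0..128   (32 x Float32)
    #   b: bytes 128..136   (Int64)
    #   c: bytes 136..144   (complex = two Float32)
    #   x: bytes 256..384   (padded up to its strict 128-byte alignment)
    #   y: bytes 384..388   (Int32 at an 8-byte boundary)
    #   z: bytes 400..408   (complex at a 16-byte boundary)
    # so the struct size rounds up to 512 with 128-byte overall alignment,
    # matching the note above.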


@cute.kernel
def kernel(
    const_a: cutlass.Constexpr,
    dst_a: cute.Tensor,
    const_b: cutlass.Constexpr,
    dst_b: cute.Tensor,
    const_c: cutlass.Constexpr,
    dst_c: cute.Tensor,
):
    # Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) are reserved for the developer to use
    # Note: the initial allocator base ptr is 1024-byte aligned
    allocator = cutlass.utils.SmemAllocator()
    # base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory)

    # -- Allocate a struct --
    # Note: when an alignment is specified, max(alignment, alignof(struct)) is applied
    # reserves the struct's section in smem; elements of the struct can be accessed through the returned ptr
    struct_in_smem = allocator.allocate(SharedStorage)
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct)

    # -- Allocate a block of memory --
    # reserves a section of 64 bytes in smem, aligned to 128 bytes; returns the section's base ptr
    section_in_smem = allocator.allocate(64, byte_alignment=128)
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section)

    # -- Allocate an array --
    # reserves an int64 array of size 14 in smem; returns the array's base ptr
    array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14)
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array)

    # -- Allocate a tensor --
    # Note: use cute.ComposedLayout or cute.Layout to specify the tensor's layout
    # Note: iterator swizzle with a swizzle layout is currently not supported
    layout = cute.make_layout((16, 2))
    tensor_in_smem = allocator.allocate_tensor(
        element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None
    )
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor)

    # ptr<f32, smem, align<1024>>
    print(struct_in_smem.a.data_ptr())
    # ptr<i64, smem, align<128>>
    print(struct_in_smem.b)
    # ptr<f32, smem, align<8>>
    print(struct_in_smem.c.real)
    # ptr<i8, smem, align<512>>
    print(section_in_smem)
    # ptr<i64, smem, align<64>>
    print(array_in_smem)
    # tensor<ptr<f32, smem, align<32>> o (16,2):(1,16)>
    print(tensor_in_smem)

    # fill the MemRange tensor in the struct and copy it to dst
    a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4)))
    a_tensor.fill(const_a)
    cute.printf("cute.struct.MemRange: {}", a_tensor)
    dst_a.store(a_tensor.load())

    # view the block of smem as a tensor, fill it, and copy it to dst
    layout = cute.make_layout((8, 2))
    sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32)
    sec_tensor = cute.make_tensor(sec_ptr, layout)
    sec_tensor.fill(const_b)
    cute.printf("block of memory: {}", sec_tensor)
    dst_b.store(sec_tensor.load())

    # fill the allocated tensor in smem and copy it to dst
    tensor_in_smem.fill(const_c)
    cute.printf("tensor in smem: {}", tensor_in_smem)
    dst_c.store(tensor_in_smem.load())


@cute.jit
def run_allocation_kernel(
    const_a: cutlass.Constexpr,
    dst_a: cute.Tensor,
    const_b: cutlass.Constexpr,
    dst_b: cute.Tensor,
    const_c: cutlass.Constexpr,
    dst_c: cute.Tensor,
):
    # additional size for the example: 64 (section) + 112 (array) + 128 (tensor) < 384
    additional_bytes = 384
    # Note: the launched shared memory size is SMEM_SIZE = 512 + 384 = 896 bytes
    kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch(
        grid=(1, 1, 1),
        block=(1, 1, 1),
        smem=SharedStorage.size_in_bytes() + additional_bytes,
    )


def verify_allocation_kernel(const_a, const_b, const_c):
    dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda")
    dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda")
    dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda")

    run_allocation_kernel(
        const_a,
        from_dlpack(dst_a),
        const_b,
        from_dlpack(dst_b),
        const_c,
        from_dlpack(dst_c),
    )

    np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0])
    np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0])
    np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0])


if __name__ == "__main__":
    # prepare the CUDA context
    cutlass.cuda.initialize_cuda_context()
    # an example of shared memory allocation
    const_a = 0.5
    const_b = 1.0
    const_c = 2.0
    verify_allocation_kernel(const_a, const_b, const_c)
51  examples/python/CuTeDSL/cute/ffi/CMakeLists.txt  Normal file
@ -0,0 +1,51 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cmake_minimum_required(VERSION 3.15)
project(tensor)

# Find Python
find_package(Python COMPONENTS Interpreter Development REQUIRED)

# Get the Python site-packages directory using Python itself
execute_process(
    COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])"
    OUTPUT_VARIABLE Python_SITE_PACKAGES
    OUTPUT_STRIP_TRAILING_WHITESPACE
)

message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}")

# Add the nanobind path to CMAKE_PREFIX_PATH
list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake)

# Find nanobind
find_package(nanobind REQUIRED)

# Add the module
nanobind_add_module(tensor tensor.cpp)
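
# Typical out-of-source build (a sketch; run_test() in jit_argument.py drives
# these same two steps via subprocess):
#   cmake -B <build-dir> <path-to-this-directory>
#   cmake --build <build-dir>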
305  examples/python/CuTeDSL/cute/ffi/jit_argument.py  Normal file
@ -0,0 +1,305 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
|
||||
|
||||
This example demonstrates a basic approach to building customized interfaces as C-structures between user code
|
||||
and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions
|
||||
and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions.
|
||||
|
||||
The C-structure is defined as:
|
||||
|
||||
.. code-block:: c
|
||||
|
||||
struct Tensor {
|
||||
void *ptr; // Pointer to tensor data
|
||||
int32_t shape[3]; // Tensor dimensions
|
||||
int32_t strides[3]; // Memory strides for each dimension
|
||||
};
|
||||
|
||||
The example defines Tensor and TensorValue classes that wrap C structs for view of a tensor with its data pointer,
|
||||
shape, and strides, enabling efficient data passing between different language boundaries.
|
||||
|
||||
.. note::
|
||||
Future development may include automated code generation flows.
|
||||
"""
|
||||
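
# End-to-end flow at a glance (a sketch mirroring run_test() below; make_tensor
# and pycapsule_get_pointer come from the nanobind module built from tensor.cpp):
#
#   c_struct = make_tensor(data_ptr, shape, strides)                 # build the C struct
#   tensor   = ExampleTensor(pycapsule_get_pointer(c_struct), rank)  # wrap its pointer
#   compiled = cute.compile(foo, tensor)                             # JIT compile once
#   compiled(tensor)                                                 # call many times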

import cutlass
import cutlass.cute as cute

from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
import cutlass._mlir.extras.types as T


class ExampleTensorValue(ir.Value):
    """A wrapper class for tensor values in MLIR.

    This class extends ir.Value to provide convenient access to the tensor's data
    pointer, shape, and strides through MLIR operations.

    :type: ir.Value
    """

    def __init__(self, v):
        """Initialize a new ExampleTensorValue.

        :param v: The underlying MLIR value to wrap
        :type v: ir.Value
        """
        super().__init__(v)

    @property
    def data_ptr(self, *, loc=None, ip=None):
        """Get the data pointer from the tensor value.

        Extracts the data pointer (first field) from the LLVM struct value.

        :param loc: Optional location information for MLIR operations
        :type loc: Optional[ir.Location]
        :param ip: Optional insertion point for MLIR operations
        :type ip: Optional[ir.InsertionPoint]
        :return: A typed pointer to the tensor data
        :rtype: ir.Value
        """
        # The data pointer is the first field (index 0) in the struct;
        # use llvm.extractvalue to read the pointer field out of the struct value
        ptr_val = llvm.extractvalue(
            llvm.PointerType.get(),
            self,
            [0],  # Extract the first field (index 0)
            loc=loc,
            ip=ip,
        )

        return cute.make_ptr(cutlass.Float32, ptr_val)
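
    # Note: the element type is hard-coded to Float32 via cute.make_ptr above;
    # a more general wrapper would carry the dtype alongside the C struct (an
    # extension this example leaves out).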

    @property
    def shape(self):
        """Get the shape of the tensor.

        Extracts the shape (second field) from the LLVM struct value.

        :return: A tuple of integers representing the tensor dimensions
        :rtype: tuple[ir.Value, ...]
        """
        i32_type = ir.IntegerType.get_signless(32)
        # The shape is the second field (index 1) in the struct
        shape_val = llvm.extractvalue(
            llvm.StructType.get_literal([i32_type] * 3),
            self,
            [1],  # Extract the second field (index 1)
        )

        # Extract each dimension from the shape struct
        return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3))

    @property
    def stride(self):
        """Get the strides of the tensor.

        Extracts the strides (third field) from the LLVM struct value.

        :return: A tuple of integers representing the tensor strides
        :rtype: tuple[ir.Value, ...]
        """
        i32_type = ir.IntegerType.get_signless(32)
        # The strides are the third field (index 2) in the struct
        strides_val = llvm.extractvalue(
            llvm.StructType.get_literal([i32_type] * 3),
            self,
            [2],  # Extract the third field (index 2)
        )

        # Extract each stride from the strides struct
        return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3))


class ExampleTensor:
    """A class representing a tensor with its data pointer, shape, and strides.

    This class provides a Python interface for creating and manipulating tensor
    structures that can be passed to CuTe JIT compiled functions.

    :ivar _c_struct_p: The C struct pointer for the tensor
    :ivar _rank: The number of dimensions in the tensor
    """

    def __init__(self, c_struct_p, rank):
        """Initialize a new ExampleTensor.

        :param c_struct_p: The C struct pointer for the tensor
        :type c_struct_p: int
        :param rank: The number of dimensions in the tensor
        :type rank: int
        """
        self._c_struct_p = c_struct_p
        self._rank = rank

    def __get_mlir_types__(self):
        """Get the MLIR types for this tensor.

        Creates an LLVM structure type representing a C structure like:

        .. code-block:: c

            struct Tensor {
                void *ptr;
                int32_t shape[3];
                int32_t strides[3];
            };

        :return: A list containing the MLIR struct type
        :rtype: list[llvm.StructType]
        """

        # Get the number of dimensions from the rank
        ndim = self._rank

        # Create the pointer type (void*)
        ptr_type = llvm.PointerType.get()

        # Create literal struct types for shape and strides (ndim x int32_t)
        int32_type = ir.IntegerType.get_signless(32)
        shape_type = llvm.StructType.get_literal([int32_type] * ndim)
        strides_type = llvm.StructType.get_literal([int32_type] * ndim)

        # Create the structure type
        struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type])

        return [struct_type]

    def __new_from_mlir_values__(self, values):
        """Create a new ExampleTensorValue from MLIR values.

        :param values: A list of MLIR values
        :type values: list[ir.Value]
        :return: A new ExampleTensorValue instance
        :rtype: ExampleTensorValue
        """
        return ExampleTensorValue(values[0])

    def __c_pointers__(self):
        """Get the C pointers for this tensor.

        :return: A list containing the C struct pointer
        :rtype: list[int]
        """
        return [self._c_struct_p]
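
    # Taken together, __get_mlir_types__, __new_from_mlir_values__ and
    # __c_pointers__ form the argument protocol this example relies on for
    # custom JIT argument types: the first two describe the value at compile
    # time, while the last supplies the raw pointer(s) at call time (its cost
    # is what run_test() benchmarks below).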


@cute.jit
def foo(tensor):
    """Example JIT function that prints tensor information.

    :param tensor: An ExampleTensor instance to print information about
    :type tensor: ExampleTensor
    """
    cute.printf("data_ptr: {}", tensor.data_ptr)
    cute.printf("shape: {}", tensor.shape)
    cute.printf("stride: {}", tensor.stride)

    mA = cute.make_tensor(
        tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride)
    )
    cute.print_tensor(mA)


import sys
import os
import subprocess
import shutil
import tempfile
import torch


def run_test(tmpdir=None):
    # Skip cleanup if the user provides tmpdir
    cleanup = tmpdir is None
    # Initialize the temporary build directory
    tmpdir = tmpdir or tempfile.mkdtemp()

    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))

        subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
        subprocess.run(["cmake", "--build", tmpdir], check=True)

        sys.path.append(tmpdir)

        from tensor import make_tensor, pycapsule_get_pointer

        # Mock tensor and corresponding C structure for this example;
        # in production, these may come from an external library
        x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
        c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
        c_struct_p = pycapsule_get_pointer(c_struct)

        # Initialize the tensor wrapper and compile the test function
        tensor = ExampleTensor(c_struct_p, len(x.shape))
        compiled_func = cute.compile(foo, tensor)

        # Benchmark pointer access performance:
        # fetching the C pointers is on the critical path of every call
        # into the JIT compiled function
        from time import time

        start = time()
        for _ in range(1000):
            tensor.__c_pointers__()
        end = time()
        print(f"__c_pointers__: {(end - start) * 1000} ms")

        # Execute the compiled function
        compiled_func(tensor)
    except Exception as e:
        print(e)
    finally:
        if cleanup:
            # Clean up the temporary directory
            shutil.rmtree(tmpdir)


if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Set temporary directory for building C modules"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tmp-dir", type=str, help="Temporary directory path for building C modules"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_test(args.tmp_dir)
|
||||
82  examples/python/CuTeDSL/cute/ffi/tensor.cpp  Normal file
@ -0,0 +1,82 @@
// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:

// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.

// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.

// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.

// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.

#include <cstdint>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

namespace nb = nanobind;

// Definition of the MockTensor struct, for testing only
struct MockTensor {
  void *ptr;
  struct {
    int32_t shape[3];
  } shape;

  struct {
    int32_t strides[3];
  } strides;
};
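
// A note on the nested structs (an observation, not a nanobind requirement):
// wrapping each int32_t[3] array in an anonymous struct mirrors the literal
// LLVM struct types ({i32, i32, i32}) that ExampleTensorValue extracts
// field-by-field in jit_argument.py; a plain int32_t[3] member would have the
// same in-memory layout here.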

NB_MODULE(tensor, m) {
  // Create a tensor for testing
  m.def("make_tensor", [](int64_t ptr, std::vector<int32_t> shape,
                          std::vector<int32_t> strides) {
    auto *tensor = new MockTensor();
    tensor->ptr = reinterpret_cast<void *>(ptr);

    assert(shape.size() == 3 && "shape must have 3 elements");
    assert(strides.size() == 3 && "strides must have 3 elements");

    for (size_t i = 0; i < shape.size(); i++) {
      tensor->shape.shape[i] = shape[i];
      tensor->strides.strides[i] = strides[i];
    }

    // Hand ownership to a PyCapsule whose destructor frees the MockTensor
    return nb::steal(PyCapsule_New(tensor, "tensor", [](PyObject *capsule) {
      auto n = PyCapsule_GetName(capsule);
      if (void *p = PyCapsule_GetPointer(capsule, n)) {
        delete reinterpret_cast<MockTensor *>(p);
      }
    }));
  });

  m.def(
      "pycapsule_get_pointer",
      [](nb::object &capsule) {
        void *ptr = PyCapsule_GetPointer(capsule.ptr(), "tensor");
        if (!ptr) {
          throw std::runtime_error("Invalid tensor capsule");
        }
        return reinterpret_cast<uintptr_t>(ptr);
      },
      "Get pointer from PyCapsule");
}
1486  examples/python/CuTeDSL/hopper/dense_gemm.py  Normal file
File diff suppressed because it is too large
@ -83,11 +83,6 @@
"\n",
"    # Print hello world from host code\n",
"    cute.printf(\"hello world\")\n",
"    \n",
"    # Initialize CUDA context for launching a kernel with error checking\n",
"    # We make context initialization explicit to allow users to control the context creation \n",
"    # and avoid potential issues with multiple contexts\n",
"    cutlass.cuda.initialize_cuda_context()\n",
"\n",
"    # Launch kernel\n",
"    kernel().launch(\n",
@ -129,6 +124,11 @@
}
],
"source": [
"# Initialize CUDA context for launching a kernel with error checking\n",
"# We make context initialization explicit to allow users to control the context creation \n",
"# and avoid potential issues with multiple contexts\n",
"cutlass.cuda.initialize_cuda_context()\n",
"\n",
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
"print(\"Running hello_world()...\")\n",
"hello_world()\n",
@ -136,6 +136,7 @@
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
"print(\"Compiling...\")\n",
"hello_world_compiled = cute.compile(hello_world)\n",
"\n",
"# Run the pre-compiled version\n",
"print(\"Running compiled version...\")\n",
"hello_world_compiled()"