v4.0 update. (#2371)

Authored by Junkai-Wu on 2025-06-06 14:39:20 +08:00, committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -42,7 +42,6 @@
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/arch/grid_dependency_control.h"
@ -288,7 +287,7 @@ struct CollectiveMma<
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
@ -445,7 +444,7 @@ struct CollectiveMma<
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
++k_tile_iter;
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
launch_dep_grids = true;
cutlass::arch::launch_dependent_grids();
}
@ -453,7 +452,7 @@ struct CollectiveMma<
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
if (!disable_gdc && !launch_dep_grids) {
cutlass::arch::launch_dependent_grids();
}
}
@ -533,7 +532,7 @@ struct CollectiveMma<
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
++k_tile_iter;
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
launch_dep_grids = true;
cutlass::arch::launch_dependent_grids();
}
@ -541,7 +540,7 @@ struct CollectiveMma<
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
if (!disable_gdc && !launch_dep_grids) {
cutlass::arch::launch_dependent_grids();
}
}
@ -634,9 +633,9 @@ struct CollectiveMma<
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
@ -854,7 +853,7 @@ struct CollectiveMma<
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();

View File

@ -133,7 +133,7 @@ using TP = _8;
static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
// Distributed GEMM tiling/sharding schedule
// Choices:
@ -252,7 +252,8 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@ -345,7 +346,7 @@ struct Result {
};
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -803,17 +804,18 @@ int run(Options &options) {
return 0;
}
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
// CUTLASS must be compiled with CUDA Toolkit 12.6 or newer to run this example
// and must have compute capability at least 90.
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
// Some necessary cuda graph APIs were only introduced in CUDA 12.6.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) {
std::cerr << "This example requires CUDA 12.6 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
@ -857,11 +859,11 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
run(options);
#else
std::cerr
<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.4 or later." << std::endl;
<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.6 or later." << std::endl;
return 0;
#endif

View File

@ -250,8 +250,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
/// Result structure
struct Result
{
@ -518,7 +516,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
arguments.scheduler.raster_order = options.raster;
arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler will swizzle up to 8, rounded to the nearest power of 2 (i.e., 1, 2, 4, or 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
@ -690,10 +688,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) {
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster == RasterOrderOptions::AlongM) {
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
@ -747,7 +745,7 @@ int main(int argc, char const **args) {
// Parse options
//
Options<RasterOrderOptions, ProblemShape> options;
Options<ProblemShape> options;
options.parse(argc, args);

View File

@ -253,8 +253,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
/// Result structure
struct Result
{
@ -523,7 +521,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
arguments.scheduler.raster_order = options.raster;
arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler will swizzle up to 8, rounded to the nearest power of 2 (i.e., 1, 2, 4, or 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
@ -699,10 +697,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) {
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster == RasterOrderOptions::AlongM) {
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
@ -755,7 +753,7 @@ int main(int argc, char const **args) {
// Parse options
//
Options<RasterOrderOptions, ProblemShape> options;
Options<ProblemShape> options;
options.parse(argc, args);

View File

@ -30,10 +30,9 @@
**************************************************************************************************/
// Command line options parsing
template<typename _RasterOrderOptions, typename _ProblemShape>
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
template<typename _ProblemShape>
struct Options {
using RasterOrderOptions = _RasterOrderOptions;
using ProblemShape = _ProblemShape;
bool help = false;
@ -50,7 +49,7 @@ struct Options {
int const m_alignment = 128;
int const n_alignment = 128;
RasterOrderOptions raster;
RasterOrderOptions raster_order;
int swizzle;
// Parses the command line
@ -74,13 +73,13 @@ struct Options {
cmd.get_cmd_line_argument("raster", raster_char);
if (raster_char == 'N' || raster_char == 'n') {
raster = RasterOrderOptions::AlongN;
raster_order = RasterOrderOptions::AlongN;
}
else if (raster_char == 'M' || raster_char == 'm') {
raster = RasterOrderOptions::AlongM;
raster_order = RasterOrderOptions::AlongM;
}
else if (raster_char == 'H' || raster_char == 'h') {
raster = RasterOrderOptions::Heuristic;
raster_order = RasterOrderOptions::Heuristic;
}
cmd.get_cmd_line_argument("swizzle", swizzle, 1);

View File

@ -543,7 +543,7 @@ int run(Options &options) {
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// CUTLASS must be compiled with CUDA 12.8 Toolkit or newer to run this example
// and must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -560,7 +560,6 @@ int main(int argc, char const **args) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
//
// Parse options
//

View File

@ -237,7 +237,7 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {

View File

@ -300,7 +300,7 @@ auto make_iterator(T* ptr) {
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {

View File

@ -490,7 +490,7 @@ int run(Options &options)
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -490,7 +490,7 @@ int run(Options &options)
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -499,11 +499,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
//
// Parse options
//

View File

@ -117,15 +117,17 @@ struct Options {
int q = 256;
int k = 256;
int d = 128;
int warmup_iterations = 1;
int iterations = 3;
int tensor_ring_buffers = 1;
bool verify = false;
bool verbose = false;
bool causal = false;
bool residual = false;
bool varlen = false;
bool persistent = false;
int sm_count = 0;
std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom;
@ -189,10 +191,15 @@ struct Options {
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "no" || mask == "") {
@ -210,7 +217,6 @@ struct Options {
causal = false;
}
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
@ -235,10 +241,13 @@ struct Options {
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
@ -379,40 +388,55 @@ struct FwdRunner {
StrideLSE stride_LSE;
uint64_t seed = 0;
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
struct DeviceBuffer {
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
DeviceBuffer() = default;
DeviceBuffer(const DeviceBuffer&) = delete;
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
size_t get_storage_size() const {
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
+ device_cumulative_seqlen_kv.get_storage_size();
}
};
std::vector<std::unique_ptr<DeviceBuffer>> buffers;
std::vector<int> cumulative_seqlen_q;
std::vector<int> cumulative_seqlen_kv;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
//
// Methods
//
bool verify(const ProblemShapeType& problem_shape) {
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
select<1,2,3>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
select<1,2,3>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
select<0,2,3>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
@ -431,7 +455,7 @@ struct FwdRunner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
@ -439,14 +463,13 @@ struct FwdRunner {
<< " mean " << mean_diff << std::endl;
}
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
bool passed_LSE = true; // future work
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
// if ( ! passed_LSE) {
// std::cerr << "failed LSE: max diff " << max_diff
// << " mean " << mean_diff << std::endl;
// }
bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
return passed_O && passed_LSE;
}
@ -559,50 +582,70 @@ struct FwdRunner {
get<1,1>(stride_LSE) = 0;
}
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_LSE.reset(size(shape_LSE));
block_ref_O.reset(size(shape_QO));
block_ref_LSE.reset(size(shape_LSE));
auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v);
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
if ( ! cumulative_seqlen_q.empty()) {
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
if ( ! cumulative_seqlen_q.empty()) {
buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
buffer.device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
buffer.device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
};
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
int tensor_ring_buffers = options.tensor_ring_buffers;
for (int i = 1; i < tensor_ring_buffers; i++) {
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
}
if constexpr (kIsVarlen) {
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
}
return problem_shape;
}
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
}
typename Operation::Arguments arguments{
problem_shape_,
{ buffers[buffer_index]->block_Q.get(), stride_Q,
buffers[buffer_index]->block_K.get(), stride_K,
buffers[buffer_index]->block_V.get(), stride_V },
{ buffers[buffer_index]->block_O.get(), stride_O,
buffers[buffer_index]->block_LSE.get(), stride_LSE },
hw_info
};
return arguments;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_shape = initialize(options);
typename Operation::Arguments arguments{
problem_shape,
{ block_Q.get(), stride_Q,
block_K.get(), stride_K,
block_V.get(), stride_V },
{ block_O.get(), stride_O,
block_LSE.get(), stride_LSE },
hw_info
};
int buffer_index = 0;
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
Operation op;
@ -630,11 +673,21 @@ struct FwdRunner {
}
// Run
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
for (int i = 0; i < options.warmup_iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
}
cudaError_t result = cudaDeviceSynchronize();
@ -672,6 +725,14 @@ struct FwdRunner {
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
}
//
@ -734,10 +795,10 @@ struct FwdRunner {
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape);
passed = verify(problem_shape, *buffers[0]);
if (passed) example_result.verified = true;
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;
@ -789,10 +850,14 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
using HeadDim = _128;
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -818,10 +883,14 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _64;
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
}
@ -845,10 +914,14 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _32;
#ifdef FP8
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
#endif
}

View File

@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
///////////////////////////////////////////////////////////////////////////////////////////////////
enum class InitStyle {
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
};
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -98,6 +98,9 @@ struct Options {
if (s == "r") {
dst = InitStyle::kRandom;
}
else if (s == "l") {
dst = InitStyle::kRandomLarge;
}
else if (s == "1") {
dst = InitStyle::kOne;
}
@ -203,6 +206,11 @@ void initialize_block(
block.get(), block.size(), seed, (Element) -1, (Element) 1);
break;
}
case InitStyle::kRandomLarge: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 100);
break;
}
case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {

View File

@ -144,4 +144,23 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
endforeach()
# Add a target that builds all examples
add_custom_target(77_blackwell_fmha_all
DEPENDS
77_blackwell_fmha_fp8
77_blackwell_fmha_fp16
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen_fp16
77_blackwell_mla_2sm_fp8
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_mla_b2b_2sm_fp8
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_fmha_bwd_sat_fp8
77_blackwell_fmha_bwd_sat_fp16
)
endif()

View File

@ -157,8 +157,8 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
}
template<class BlkCoord, class TileShape, class ProblemSize>

View File

@ -42,7 +42,7 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE // Q, B
class StrideLSE_ // Q, B
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
@ -54,6 +54,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
struct TensorStorage {
@ -79,6 +80,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
struct Params {
TMA_O tma_store_o;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
template<class ProblemShape>
@ -110,7 +114,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
);
return {
tma_store_o
tma_store_o,
args.ptr_LSE,
args.dLSE
};
}
@ -119,6 +125,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}
const Params& params;
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(

View File

@ -531,7 +531,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
@ -613,7 +613,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
const int kReleasePipeCount = 10; // must be multiple of 2
order_s.wait();
CUTLASS_PRAGMA_UNROLL
@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
#ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {
@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 16;
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}
template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
@ -951,7 +951,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
@ -961,7 +962,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
@ -1060,13 +1061,25 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
@ -1083,6 +1096,16 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);

View File

@ -118,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);
return Params{
args.problem_shape,
args.mainloop,
@ -452,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV);
@ -477,7 +477,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
// load Q
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -520,7 +520,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
@ -529,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
// load V
if (cute::elect_one_sync()) {
cute::copy(
@ -540,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
// load dO
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -573,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
// load Q
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -584,7 +584,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
@ -593,15 +593,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
// load dO
if (cute::elect_one_sync()) {
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -612,7 +612,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
@ -621,7 +621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
@ -639,23 +639,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
@ -685,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk;
@ -923,6 +923,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord,
TensorShape const& tensor_shape) {
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
@ -930,21 +932,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
auto thr_copy = copy_op.get_slice(_0{});
auto tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);
Tensor tCg = thr_copy.partition_D(gmem);
Tensor tCr = thr_copy.partition_S(quantize(regs));
Tensor tPc = thr_copy.partition_D(preds);
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
copy_if(copy_op, tPc, tCr, tCg);
}
@ -1073,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, HB] = problem_shape;
// in tmem, S & P overlap
@ -1114,7 +1106,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
@ -1152,20 +1144,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
fn(cute::false_type{});
}
};
dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {
// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
@ -1184,16 +1176,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}
auto tRT_rST = quantize(tTR_rST);
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
});
@ -1293,9 +1285,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
using X = Underscore;
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
@ -1307,7 +1299,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1376,7 +1368,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;
}
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync();
@ -1561,7 +1553,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
@ -1587,7 +1579,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();
load(
blk_coord,
problem_shape,
@ -1596,7 +1588,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
@ -1608,7 +1600,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
mma(
blk_coord,
problem_shape,
@ -1616,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
@ -1629,7 +1621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();
compute(
blk_coord,
problem_shape,
@ -1660,7 +1652,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();
reduce(
blk_coord,
problem_shape,
@ -1677,9 +1669,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();
/* no-op */
}
}

View File

@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
CollectiveMainloop mainloop;
CollectiveEpilogue epilogue;
CollectiveEpilogue epilogue{params.epilogue};
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
@ -407,7 +407,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);

View File

@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}
CUTLASS_PRAGMA_UNROLL
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}

View File

@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}
if (threadIdx.x == 0) {
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}

View File

@ -75,6 +75,8 @@ struct DeviceAllocation {
size_t size() const { return size_; }
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess);

View File

@ -280,7 +280,7 @@ auto make_iterator(T* ptr) {
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {

View File

@ -133,7 +133,7 @@ using TP = _8;
static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
// Distributed GEMM tiling/sharding schedule
// Choices:
@ -254,7 +254,8 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@ -347,7 +348,7 @@ struct Result {
};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -805,17 +806,16 @@ int run(Options &options) {
return 0;
}
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
// and must have compute capability at least 90.
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
// CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run Blackwell kernels.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
@ -861,11 +861,11 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
run(options);
#else
std::cerr
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.4 or later." << std::endl;
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.8 or later." << std::endl;
return 0;
#endif

View File

@ -14,8 +14,8 @@ cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
### Minimum software
Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
CUDA graph APIs.
This example specifically requires CUDA Toolkit 12.8 or newer, since that is the first version
supporting the Blackwell architecture.
### Hardware / driver settings

File diff suppressed because it is too large

View File

@ -0,0 +1,50 @@
# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
88_hopper_fmha
88_hopper_fmha.cu
)
if(NOT WIN32 AND NOT CUTLASS_CLANG_HOST_COMPILE)
set_property(
SOURCE 88_hopper_fmha.cu
PROPERTY COMPILE_FLAGS "--use_fast_math"
)
cutlass_example_add_executable(
88_hopper_fmha_fp8
88_hopper_fmha.cu
)
target_compile_definitions(
88_hopper_fmha_fp8
PRIVATE FP8)
endif()

View File

@ -0,0 +1,77 @@
# CUTLASS Hopper FMHA Example
This sample showcases how to implement fused multi-head attention (FMHA) using
CUTLASS for the NVIDIA Hopper architecture. At its heart, the forward pass of
FMHA is a GEMM-online-softmax-GEMM fusion, whereas the backward pass has a slightly
more complex structure (basically, a GEMM-softmax-2xGEMM-2xGEMM fusion).
For more information please refer to the [Flash Attention 3 paper](https://arxiv.org/abs/2407.08608).
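For orientation, the computation the two passes implement can be summarized as follows
(a minimal sketch in standard attention notation, not copied from the kernels; $d$ is the head dim,
$L$ is the row-wise log-sum-exp kept from the forward pass, and $D = \mathrm{rowsum}(dO \odot O)$):

$$
S = QK^\top, \qquad P = \mathrm{softmax}\left(S/\sqrt{d}\right), \qquad O = PV
$$

$$
dV = P^\top dO, \qquad dP = dO\,V^\top, \qquad dS = P \odot (dP - D)/\sqrt{d}, \qquad dQ = dS\,K, \qquad dK = dS^\top Q
$$

The forward kernel evaluates the softmax online while streaming over KV tiles; the backward kernel
recomputes $P$ from $S$ and $L$ as $P = \exp(S/\sqrt{d} - L)$, which accounts for the extra GEMM in
the fusion structure described above.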
The forward pass kernel supports head dims 32, 64, 128, and 256 for fp16 and bf16 input data types,
and head dims 128 and 256 for fp8.
All kernels use the Tensor Memory Accelerator for loads.
Kernels with head dims 128 and 256 have warp-specialized cooperative schedules.
Backward pass kernels (fp16 only) support head dims 32, 64, and 128, and all support
warp-specialized cooperative schedules.
## Customization
### Mask Fusion
Similar to the [Blackwell FMHA example](../77_blackwell_fmha/README.md), attention masks such as
causal masking can be fused into the kernel. To modify the code for such fusions,
`collective/fmha_fusion.hpp` provides the easiest customization point.
The `before_softmax` function is called with the accumulator of the first GEMM and the logical
positions of those elements. It is well-suited for applying masks or activations.
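As a hedged illustration (not the example's actual fusion struct), a causal mask specialized through
this hook could look roughly like the sketch below. `CausalMaskFusion` and its internals are
hypothetical; only the calling convention mirrors how the mainloops invoke
`Fusion{}.before_softmax(acc, index, problem_size)`.

```cpp
#include "cutlass/cutlass.h"   // CUTLASS_DEVICE, CUTLASS_PRAGMA_UNROLL
#include "cute/tensor.hpp"     // cute::size, cute::get
using namespace cute;

// Hypothetical causal-mask fusion in the style of collective/fmha_fusion.hpp.
//   acc_qk   : per-thread register fragment of the Q*K^T accumulator
//   index_qk : matching fragment of logical (q_idx, kv_idx, ...) coordinates
struct CausalMaskFusion {
  template <class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE void
  before_softmax(AccQK& acc_qk, IndexQK const& index_qk, ProblemSize const& problem_size) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); ++i) {
      auto q_idx  = get<0>(index_qk(i));   // logical row of this accumulator element
      auto kv_idx = get<1>(index_qk(i));   // logical column
      if (kv_idx > q_idx) {
        acc_qk(i) = -INFINITY;             // masked entries become zero after softmax
      }
    }
  }
  // A real fusion would also specialize the trip-count / contribution hooks
  // (get_trip_count, is_contributing) so that fully-masked KV tiles are skipped.
};
```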
### MHA Variants
Using CuTe, it is easy to represent the various attention variants.
Where regular multi-head attention's layout for the head dimension is (numHeads:headStride),
for single-head attention it is simply (1:0) everywhere,
for GQA it is normal in Q and (numHeads/numGroups,numGroups:headStride,0) in KV,
and for MQA it is normal in Q and (numHeads:0) in KV.
As such, beyond general stride handling, no additional work is needed to support these,
and the example will just demonstrate regular multi-head attention.
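A hedged CuTe sketch of these head-mode layouts is shown below; the variable names are
illustrative, not identifiers from the example. The zero stride is what makes several
query heads alias the same KV head.

```cpp
#include "cute/tensor.hpp"
using namespace cute;

void kv_head_layout_sketch() {
  int numHeads   = 16;   // query heads
  int numGroups  = 4;    // query heads sharing one KV head (GQA group size)
  int headStride = 128;  // elements between consecutive KV heads

  // MHA: every query head has its own KV head, (numHeads:headStride).
  auto mha_kv = make_layout(make_shape(numHeads), make_stride(headStride));

  // GQA: (numHeads/numGroups,numGroups : headStride,0) -- the zero-stride mode
  // maps all query heads within a group onto the same KV head.
  auto gqa_kv = make_layout(make_shape(numHeads / numGroups, numGroups),
                            make_stride(headStride, 0));

  // MQA: (numHeads:0) -- every query head reads the single KV head.
  auto mqa_kv = make_layout(make_shape(numHeads), make_stride(0));

  (void)mha_kv; (void)gqa_kv; (void)mqa_kv;
}
```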
### FP8
The warp-specialized forward kernel supports FP8 computation with both FP32 and FP16
accumulation for the Q*K product. These variants can be enabled in the runner by defining FP8.
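A rough sketch of how such a compile-time switch typically looks (hedged: the FP8 macro comes from
the `88_hopper_fmha_fp8` target above, but the exact type aliases used by the runner may differ):

```cpp
#include "cutlass/numeric_types.h"

// Hedged sketch: the runner is assumed to select its input element type
// based on the FP8 macro defined by the 88_hopper_fmha_fp8 target.
#if defined(FP8)
using Element            = cutlass::float_e4m3_t;  // FP8 inputs for Q, K, V
#else
using Element            = cutlass::half_t;        // default fp16 inputs
#endif
using ElementAccumulator = float;                  // FP32 accumulation for Q*K^T
                                                   // (an FP16 accumulator variant also exists)
```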
## Performance
Forward pass kernels can generally come close to the performance of FA3, but backward pass
kernels are more limited and are not expected to reach the same level of performance as FA3.
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -0,0 +1,863 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::collective {
template<
typename Element_,
typename ElementAccumulator_,
typename TileShape_, // BlockQO, BlockKV, BlockHead
class Fusion,
class... Options
>
struct FmhaBwdMainloopTmaWarpSpecialized {
using Element = Element_;
using ElementAccumulator = ElementAccumulator_;
using TileShape = TileShape_;
static constexpr bool kIsPersistent = false;
static const int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = 2;
static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups;
static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */;
static const int kOuterLoads = 2;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<_1, _1, _1>;
static_assert(StagesQ::value >= 2);
static_assert(Stages::value >= 2 * NumMmaWarpGroups);
// 16B alignment lets us use TMA
static constexpr int Alignment = 16 / sizeof(Element);
using TileShapeNM = Shape< // (N,M,D)
decltype(tuple_element_t<1, TileShape>{} / Int<NumMmaWarpGroups>{}),
tuple_element_t<0, TileShape>,
tuple_element_t<2, TileShape>>;
using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M)
using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N)
using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeNM, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeND, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeND, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment, // from smem, might matter (?)
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeMD, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using TiledMmaNM = typename CollectiveMmaNM::TiledMma;
using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma;
using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{}));
using TiledMmaND = TiledMmaND_RS;
using TiledMmaMD = typename CollectiveMmaMD::TiledMma;
using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB;
using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA;
using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA;
using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB;
//using SmemLayoutDQ = Layout<
// Shape<
// tuple_element_t<0, TileShapeMD>,
// Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>,
// _2
// >,
// Stride<
// _4,
// Stride<decltype(tuple_element_t<0, TileShapeMD>{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>,
// decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
// >>;
using SmemLayoutDQ_0 = Layout<
Shape<
tuple_element_t<0, TileShapeMD>,
tuple_element_t<1, TileShapeMD>,
_2
>,
Stride<
tuple_element_t<1, TileShapeMD>,
_1,
decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
>>;
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>());
using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{}));
using SmemLayoutDQ = SmemLayoutDQ_1;
using PipelineDQ = cutlass::PipelineAsync<2>;
using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int<NumMmaWarpGroups>{}));
using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom<Element>{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{}));
using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB;
using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB;
using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB;
using SmemLayoutLSE = Layout<Shape<tuple_element_t<1, TileShapeNM>, Int<StageCount>>>;
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
using TileShapePV = TileShapeND; // To work with the kernel level
using TiledMmaPV = TiledMmaND;
static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator);
static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
struct SharedStorage {
// One for each consumer WG
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutKp>> smem_kp;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;
// Loaded by producer, consumed by both WGs
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQp>> smem_qp;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDOp>> smem_dop;
};
// Accumulated into by both consumers, potentially loaded, potentially written
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutDQ>> smem_dq;
union {
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_lse;
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_sumOdO;
};
};
struct Arguments {
const Element* ptr_Q;
cute::tuple<int, int, int, _1> dQ;
const Element* ptr_K;
cute::tuple<int, int, int, _1> dK;
const Element* ptr_V;
cute::tuple<int, int, int, _1> dV;
const Element* ptr_dO;
cute::tuple<int, int, int, _1> dDO;
const ElementAccumulator* ptr_LSE;
cute::tuple<int, int, _1> dLSE;
const ElementAccumulator* ptr_sum_OdO;
cute::tuple<int, int, _1> dSumOdO;
ElementAccumulator* ptr_dQ;
cute::tuple<int, int, int, _1> dDQ;
};
using TMA_Q = typename CollectiveMmaNM::Params::TMA_B;
using TMA_K = typename CollectiveMmaNM::Params::TMA_A;
using TMA_V = typename CollectiveMmaNM::Params::TMA_A;
using TMA_DO = typename CollectiveMmaNM::Params::TMA_B;
using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{})));
using TMA_ODO = TMA_LSE;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{})));
using LoadQ = CollectiveLoadTma<
LoadKind::kBwdM,
MainloopPipeline,
Element,
SmemLayoutQ,
TMA_Q
>;
using LoadK = CollectiveLoadTma<
LoadKind::kBwdN,
MainloopPipelineQ,
Element,
SmemLayoutK,
TMA_K
>;
using LoadV = CollectiveLoadTma<
LoadKind::kBwdN,
MainloopPipelineQ,
Element,
SmemLayoutV,
TMA_V
>;
using LoadDO = CollectiveLoadTma<
LoadKind::kBwdM,
MainloopPipeline,
Element,
SmemLayoutDO,
TMA_DO
>;
using LoadLSE = CollectiveLoadTma<
LoadKind::kBwdScalar,
MainloopPipeline,
ElementAccumulator,
SmemLayoutLSE,
TMA_LSE
>;
using LoadODO = CollectiveLoadTma<
LoadKind::kBwdScalar,
MainloopPipeline,
ElementAccumulator,
SmemLayoutLSE,
TMA_ODO
>;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
TMA_DO tma_load_do;
TMA_LSE tma_load_lse;
TMA_ODO tma_load_odo;
TMA_DQ tma_red_dq;
float scale_softmax;
float scale_softmax_log2;
};
static_assert(size(TiledMmaNM{}) == size(TiledMmaND{}));
static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{}));
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
return true
&& (get<4>(problem_size) <= get<2>(TileShape{}))
&& ((get<4>(problem_size) % Alignment) == 0)
&& ((get<2>(problem_size) % Alignment) == 0)
;
}
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK)));
auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ)));
auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
typename CollectiveMmaNM::Arguments {
args.ptr_K, dK,
args.ptr_Q, dQ,
}, /*workspace=*/ nullptr);
auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV)));
auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO)));
auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
typename CollectiveMmaNM::Arguments {
args.ptr_V, dV,
args.ptr_dO, dDO,
}, /*workspace=*/ nullptr);
TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{}));
TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{}));
TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{}));
return Params{
params_nm_kq.tma_load_b,
params_nm_kq.tma_load_a,
params_nm_vdo.tma_load_a,
params_nm_vdo.tma_load_b,
tma_load_lse, tma_load_odo,
tma_red_dq,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size)))
};
}
template<class BlkCoord, class ProblemSize>
CUTLASS_DEVICE
auto
get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) {
return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor());
}
template<bool kLoadOuter, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_kv_maybe_q(
int block_rank_in_cluster,
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
// Load pattern:
// K0 V0 K1 V1
// Q0 DO0 Q1 DO1 Q2 DO2 ...
// K0 Q0 V0 K1 DO0 V1 ...
int lane_predicate = cute::elect_one_sync();
int outer_tile_count = NumMmaWarpGroups;
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count);
uint16_t mcast_mask_b = 0;
LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q};
auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do};
auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse};
auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO};
auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
outer_tile_count *= 2; // K & V
inner_tile_count *= 4; // Q & dO & LSE & sumOdO
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 4;
++inner_tile_iter;
}
if constexpr (kLoadOuter) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
if constexpr (! kLoadOuter) {
if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
}
if constexpr (kLoadOuter) {
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
if constexpr (kLoadOuter) {
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
if constexpr (kLoadOuter) {
while (outer_tile_count > 0) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
}
CUTLASS_PRAGMA_NO_UNROLL
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 4;
++inner_tile_iter;
}
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
}
}
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_maybe_q(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
// Load pattern:
// K0 V0 K1 V1
// Q0 DO0 Q1 DO1 Q2 DO2 ...
// K0 Q0 V0 K1 DO0 V1 ...
int lane_predicate = cute::elect_one_sync();
int outer_tile_count = NumMmaWarpGroups;
auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
outer_tile_count *= 2; // K & V
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
while (outer_tile_count > 0) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
}
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
CUTLASS_DEVICE void
reduce(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer,
SharedStorage& storage)
{
int lane_predicate = cute::elect_one_sync();
Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size));
Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{});
Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord));
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
auto block_tma = params.tma_red_dq.get_slice(_0{});
Tensor tDQsDQ = block_tma.partition_S(sDQ);
Tensor tDQgDQ = block_tma.partition_D(gDQ);
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
int g_index = 0;
auto smem_pipe_release_reducer = smem_pipe_read_reducer;
bool first = true;
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 1;
++g_index;
}
if (inner_tile_count == 0) break;
pipeline_reducer.consumer_wait(smem_pipe_read_reducer);
if (lane_predicate == 1) {
tma_store_wait<1>();
}
if (! first) {
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
++smem_pipe_release_reducer;
} else {
first = false;
}
if (lane_predicate == 1) {
copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index));
tma_store_arrive();
}
++smem_pipe_read_reducer;
--inner_tile_count;
++g_index;
}
if (lane_predicate) {
tma_store_wait<0>();
}
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
++smem_pipe_release_reducer;
}
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
CUTLASS_DEVICE auto
compute(
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
SharedStorage& storage,
MathWgOrderBarrier& math_wg_order_barrier)
{
TiledMmaND tiled_mma_nd;
Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
clear(acc_DV);
Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
clear(acc_DK);
int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup;
PipelineState smem_pipe_release_inner = smem_pipe_read_inner;
pipeline_outer.consumer_wait(smem_pipe_read_outer);
PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer;
++smem_pipe_read_outer;
pipeline_outer.consumer_wait(smem_pipe_read_outer);
PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer;
int inner_tile_count = get_inner_tile_count(wg_coord, problem_size);
TiledMmaNM tiled_mma_nm;
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx);
Tensor tSsK = thr_mma_nm.partition_A(sK);
Tensor tSsQ = thr_mma_nm.partition_B(sQ);
Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK);
Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{});
Tensor tDPsV = thr_mma_nm.partition_A(sV);
Tensor tDPsDO = thr_mma_nm.partition_B(sDO);
Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV);
Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO);
auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx);
Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{});
Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp);
Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO);
Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{});
Tensor tDK_sQ = thr_mma_nd.partition_B(sQp);
Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ);
int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0);
TiledMmaMD tiled_mma_md;
auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx);
Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{});
Tensor tDQsDS = thr_mma_md.partition_A(sDS);
Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS);
Tensor tDQrDS = tDQrDS_full(_,_,_,_);
Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{});
Tensor tDQsK = thr_mma_md.partition_B(sKp);
Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK);
Tensor sLSE = make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
Tensor tSsLSE = thr_mma_nm.partition_C(sLSE);
Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
Tensor tDPsODO = thr_mma_nm.partition_C(sODO);
Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{}));
Tensor tScS = thr_mma_nm.partition_C(cS);
int n_block = get<1>(wg_coord);
tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{});
// Transpose
Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS))));
Tensor sDSp = sDSp_full(_,_,_);
Tensor tDPsDS = thr_mma_nm.partition_C(sDSp);
auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx);
Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp);
Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp);
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
auto tDQsDQ_full = thr_mma_md.partition_C(sDQ);
auto smem_pipe_read_k_other = smem_pipe_read_k;
smem_pipe_read_k_other.advance(2);
int k_index = 0;
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 1;
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
k_index += 1;
}
if (inner_tile_count == 0) break;
pipeline_inner.consumer_wait(smem_pipe_read_inner);
PipelineState smem_pipe_read_q = smem_pipe_read_inner;
++smem_pipe_read_inner;
PipelineState smem_pipe_read_do = smem_pipe_read_inner;
++smem_pipe_read_inner;
// GEMM KQ -> S
Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
warpgroup_fence_operand(acc_S);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S);
warpgroup_commit_batch();
pipeline_inner.consumer_wait(smem_pipe_read_do);
// GEMM VdO -> dP
Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
warpgroup_fence_operand(acc_DP);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP);
warpgroup_commit_batch();
Tensor reg_LSE = make_fragment_like<ElementAccumulator>(acc_S);
for (int i = 0; i < size(reg_LSE); i++) {
reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i);
}
Tensor reg_ODO = make_fragment_like<ElementAccumulator>(acc_S);
if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) {
for (int i = 0; i < size(reg_ODO); i++) {
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
}
}
warpgroup_wait<1>();
warpgroup_fence_operand(acc_S);
math_wg_order_barrier.wait();
// Compute S -> P
Fusion{}.before_softmax(acc_S, tScS, problem_size);
auto acc_P = make_fragment_like<ElementAccumulator>(acc_S);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_P); i++) {
acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i));
}
math_wg_order_barrier.arrive();
if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) {
for (int i = 0; i < size(reg_ODO); i++) {
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
}
}
warpgroup_wait<0>();
warpgroup_fence_operand(acc_DP);
// Compute dP P -> dS
auto acc_DS = make_fragment_like<Element>(acc_DP);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DS); i++) {
// We could move the scale out and into the respective epilogues (or a final scaling step)
acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i));
}
// GEMM PdO -> dV
auto op_P = make_acc_into_op<Element>(acc_P, typename TiledMmaND::LayoutA_TV{});
warpgroup_fence_operand(acc_DV);
warpgroup_fence_operand(op_P);
warpgroup_arrive();
cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV);
warpgroup_commit_batch();
// Store dS to smem dS'
if (wg_idx == 0) math_wg_order_barrier.wait();
auto recast_bits = [](auto sz, auto t) {
return recast<uint_bit_t<decltype(sz)::value>>(t);
};
auto tDPsDS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, tDPsDS);
auto acc_DS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, acc_DS);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DS_v); i++) {
tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i);
}
cutlass::arch::fence_view_async_shared();
if (wg_idx == 0) math_wg_order_barrier.arrive();
// GEMM dS Q -> dK
if (wg_idx == 1) {
math_wg_order_barrier.wait();
// GEMM dS' K -> dQ
Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{}));
warpgroup_fence_operand(acc_DQ);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ);
cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ);
warpgroup_commit_batch();
warpgroup_fence_operand(acc_DK);
warpgroup_arrive();
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
warpgroup_commit_batch();
warpgroup_wait<1>();
warpgroup_fence_operand(acc_DK);
warpgroup_wait<1>();
warpgroup_fence_operand(acc_DQ);
math_wg_order_barrier.arrive();
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index());
// Store dQ to smem dQ'
// Invoke TMA reduce on dQ'
using Vec = uint_bit_t<sizeof_bits_v<ElementAccumulator> * 2>;
auto tDQsDQ_v = recast<Vec>(tDQsDQ);
auto acc_DQ_v = recast<Vec>(acc_DQ);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DQ_v); i++) {
tDQsDQ_v(i) = acc_DQ_v(i);
}
cutlass::arch::fence_view_async_shared();
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
++smem_pipe_write_reducer;
} else {
warpgroup_fence_operand(acc_DK);
warpgroup_arrive();
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
warpgroup_commit_batch();
warpgroup_wait<1>();
warpgroup_fence_operand(acc_DK);
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
++smem_pipe_write_reducer;
}
--inner_tile_count;
pipeline_inner.consumer_release(smem_pipe_release_inner);
++smem_pipe_release_inner;
pipeline_inner.consumer_release(smem_pipe_release_inner);
++smem_pipe_release_inner;
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
k_index += 1;
}
pipeline_outer.consumer_release(smem_pipe_read_k);
pipeline_outer.consumer_release(smem_pipe_read_outer);
pipeline_reducer.producer_tail(smem_pipe_write_reducer);
++smem_pipe_read_outer;
warpgroup_wait<0>();
warpgroup_fence_operand(acc_DK);
warpgroup_fence_operand(acc_DV);
return make_tuple(acc_DK, acc_DV);
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,140 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
enum class LoadKind {
kQ, kK, kV,
kBwdN, kBwdM, kBwdScalar
};
template<
LoadKind kKind,
class Pipeline,
class Element,
class SmemLayout,
class TMA
>
struct CollectiveLoadTma {
using Params = TMA;
using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayout>>;
using PipelineState = typename cutlass::PipelineState<Pipeline::Stages>;
Params const& params;
Pipeline& pipeline;
SharedStorage& storage;
CUTLASS_DEVICE
CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage)
: params(params), pipeline(pipeline), storage(storage) {}
template<class ProblemSize, class TileShape, class BlockCoord>
CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape,
BlockCoord const& blk_coord, int loop_count
) {
using X = Underscore;
if constexpr (kKind == LoadKind::kK) {
Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord));
return gK;
} else if constexpr (kKind == LoadKind::kQ) {
Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord));
return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout());
} else if constexpr (kKind == LoadKind::kV) {
Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size)));
Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord));
return gV;
} else if constexpr (kKind == LoadKind::kBwdN) {
Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout());
} else if constexpr (kKind == LoadKind::kBwdM) {
Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
return g;
} else if constexpr (kKind == LoadKind::kBwdScalar) {
Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size));
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, X>{});
Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord));
return g;
}
}
template<class ClusterRank, class ProblemSize, class TileShape, class BlockCoord>
CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster,
ProblemSize const& problem_size, TileShape const& tile_shape,
BlockCoord const& block_coord, int loop_count
) {
Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count);
Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{});
auto block_tma = params.get_slice(block_rank_in_cluster);
Tensor ts = block_tma.partition_D(s);
Tensor tg = block_tma.partition_S(g);
return make_tuple(tg, ts);
}
template<bool kAdvanceIterator=true, bool kAdvancePipe=true, bool kAcquireBarrier=true, class TileIterator, class State>
CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state,
PipelineState& smem_pipe_write,
int lane_predicate, int& tile_count, uint16_t mcast_mask = 0
) {
if ((lane_predicate == 1) && (tile_count > 0)) {
if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write);
using BarrierType = typename Pipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
if constexpr (kKind == LoadKind::kBwdScalar) {
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index()));
} else {
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index()));
}
if constexpr (kAdvancePipe) ++smem_pipe_write;
if constexpr (kAdvanceIterator) ++tile_iter;
}
--tile_count;
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,305 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "../collective/fmha_common.hpp"
namespace cutlass::fmha::collective {
template<
class ElementAccumulator,
class Fusion,
class Params
>
struct CollectiveSoftmax {
Params const& params;
CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {}
using SumType = float;
using MaxType = ElementAccumulator;
template<class AccPV, class TiledMmaPV>
CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) {
Tensor s_max = make_fragment_like<MaxType>(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout())));
Tensor a_sum = make_fragment_like<SumType>(s_max);
return make_tuple(s_max, a_sum);
}
CUTLASS_DEVICE float overload_exp2(float f) {
return ::exp2f(f);
}
CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) {
auto a = f.raw();
decltype(a) d;
asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a));
return cutlass::half_t::bitcast(d);
}
CUTLASS_DEVICE float overload_max(float a, float b) {
return ::max(a, b);
}
CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) {
return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())};
}
CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) {
return f.to_half();
}
CUTLASS_DEVICE float overload_to_native(float f) {
return f;
}
template<class AccQK, class TiledMmaQK, class CountQK, class State, class ProblemShape>
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) {
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
// Linear reduction is faster for the first iteration
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = acc_qk_mn(i, 0);
}
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < size<1>(acc_qk_mn); j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
}
}
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
}
}
});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
MaxType scale_max = scale * local_max;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
}
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
}
}
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
if constexpr (kUseFusion) {
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
}
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
Tensor s_max_prev = make_fragment_like(s_max);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max_prev(i) = s_max(i);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
// Linear reduction is faster here, as well
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
}
}
// reduce max
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride<r>(reduction_target_qk) * j));
}
}
});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i);
float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2);
a_sum(i) *= scale;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
acc_pv_mn(i, j) *= scale;
}
}
}
template<class AccQK_MN, class State>
CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) {
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(acc_qk_mn); j++) {
float local_max = s_max(j) == -INFINITY ? 0.f : s_max(j);
float scale_max = params.scale_softmax_log2 * local_max;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<1>(acc_qk_mn); k++) {
acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max);
a_sum(j) += acc_qk_mn(j, k);
}
}
}
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
if constexpr (kUseFusion) {
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
}
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
Tensor s_max_prev = make_fragment_like(s_max);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max_prev(i) = s_max(i);
// Linear reduction is faster here, as well
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
}
// reduce max
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
}
});
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
MaxType scale_max = scale * local_max;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
}
MaxType s_max_cur = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale);
a_sum(i) *= scale_pv;
using ElementPV = typename AccPV::value_type;
ElementPV scale_pv_ele = static_cast<ElementPV>(scale_pv);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
acc_pv_mn(i, j) *= scale_pv_ele;
}
a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
}
}
template<class State, class AccPV, class TiledMmaPV>
CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) {
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
auto reduction_target = reduction_target_n(tiled_mma_pv);
constexpr int red_rank = decltype(rank(reduction_target))::value;
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target); j *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride<r>(reduction_target) * j);
}
}
});
Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
Tensor lse = make_fragment_like(a_sum);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_mn); i++) {
float sum = a_sum(i);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum);
lse(i) = (sum == 0.f || sum != sum) ? INFINITY : s_max(i) * params.scale_softmax + __logf(sum);
float scale = params.rp_dropout * inv_sum;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_mn); j++) {
acc_mn(i, j) *= scale;
}
}
return lse;
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,526 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
using cutlass::fmha::kernel::Tag;
using cutlass::fmha::kernel::find_option_t;
template<
typename Element_,
typename ElementAccumulator_,
typename TileShape_, // BlockQO, BlockKV, BlockHead
class Fusion,
class... Options
>
struct FmhaMainloopTma {
using Element = Element_;
using ElementAccumulator = ElementAccumulator_;
using TileShape = TileShape_;
// Options
using kClusterM = find_option_t<Tag::kClusterM, Int<1>, Options...>;
static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<4>, Options...>::value;
static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<1>, Options...>::value;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<kClusterM, _1, _1>;
// 16B alignment lets us use TMA
static constexpr int Alignment = 16 / sizeof(Element);
using TileShapeQK = TileShape;
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
using LayoutQKV = cute::tuple<int, _1, cute::tuple<int, int>>;
using LayoutQ = LayoutQKV;
using LayoutK = LayoutQKV;
using LayoutV = LayoutQKV;
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, LayoutQ, Alignment,
Element, LayoutK, Alignment,
ElementAccumulator,
TileShapeQK, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since A (the softmaxed QK accumulator) is fed from registers and never loaded from smem
Element, LayoutK, Alignment,
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
ElementAccumulator,
TileShapePV, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
using TileShapeOut = TileShapePV;
using TiledMmaOut = TiledMmaPV;
using ElementOut = ElementAccumulator;
struct SharedStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
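// K and V tiles share one smem ring buffer: a single MainloopPipeline alternates K and V stages,
// so smem_k and smem_v alias the same storage.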
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
};
struct Arguments {
const Element* ptr_Q;
LayoutQ dQ;
const Element* ptr_K;
LayoutK dK;
const Element* ptr_V;
LayoutV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
float scale_softmax;
float scale_softmax_log2;
float rp_dropout;
};
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kQ,
MainloopPipelineQ,
Element,
SmemLayoutQ,
TMA_Q
>;
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kK,
MainloopPipeline,
Element,
SmemLayoutK,
TMA_K
>;
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kV,
MainloopPipeline,
Element,
SmemLayoutV,
TMA_V
>;
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
static const int MaxThreadsPerBlock = size(typename CollectiveMmaQK::TiledMma{});
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
return true
&& (get<4>(problem_size) <= get<2>(TileShape{}))
&& ((get<4>(problem_size) % Alignment) == 0)
&& ((get<2>(problem_size) % Alignment) == 0)
;
}
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
typename CollectiveMmaQK::Arguments {
args.ptr_Q, args.dQ,
args.ptr_K, args.dK,
}, /*workspace=*/ nullptr);
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
typename CollectiveMmaPV::Arguments {
args.ptr_K, args.dK, // never used, dummy
args.ptr_V, select<1,0,2>(args.dV),
}, /*workspace=*/ nullptr);
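// Softmax is evaluated in the log2 domain: scale_softmax = 1/sqrt(head_dim) and
// scale_softmax_log2 = log2(e)/sqrt(head_dim), so exp(scale * x) can be computed with exp2.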
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
1.0f
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
template<class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
compute(
int block_rank_in_cluster,
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q,
SharedStorage& storage)
{
int warp_idx = cutlass::canonical_warp_idx_sync();
int thread_idx = threadIdx.x;
PipelineState smem_pipe_release = smem_pipe_read;
[[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1);
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
// Elect a single lane in the warp to issue TMA loads
int lane_predicate = cute::elect_one_sync();
// Issue TmaLoads (Prologue fetches)
if (warp_idx == 0) {
auto q_tile_iter = cute::make_coord_iterator(1);
int q_tile_count = 1;
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
// Loop over K elems
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
int k_tile_count_tma = 2 * fusion_tile_count;
uint16_t mcast_mask_b = 0;
if (warp_idx == 0 && lane_predicate == 1) {
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
}
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < StageCount; i++) {
if (i % 2 == 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
} else {
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
}
}
TiledMmaQK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
// Mainloop setup QK
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_N,MMA_K,PIPE)
// Prepare: MMA PV
TiledMmaPV tiled_mma_pv;
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
// Mainloop setup PV
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_N,MMA_K,PIPE)
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
pipeline_q.consumer_wait(smem_pipe_read_q);
// mapping into QK accumulator
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
Tensor tPcP = thr_mma_qk.partition_C(cP);
int m_block = get<0>(blk_coord);
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
// Allocate PV acc
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulator, Fusion, decltype(params)> softmax{params};
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
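// The first KV tile is peeled off so the PV accumulator can be zero-initialized with gemm_zero_acc;
// the subsequent loops accumulate into it.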
if (true)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
Tensor acc_qk_fixed = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})));
Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
acc_qk_input(i) = static_cast<Element>(acc_qk(i));
}
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
//
// Advance the pipe
//
// Advance consumer pipeline
++smem_pipe_read;
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
if (warp_idx == 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
softmax.template step_interleave_begin<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
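// Interleave softmax rescaling with the PV MMA: each K-slice of the QK accumulator is rescaled,
// converted to Element, and issued as the register-sourced A operand of one GMMA.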
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tOrV); i++) {
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(acc_qk_element_mk); j++) {
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
}
warpgroup_arrive();
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(tOrV); j++) {
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
}
}
warpgroup_commit_batch();
// Release the consumed pipeline stage
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
if (warp_idx == 0) {
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
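// The remaining KV tiles take the masked path: same loop body, with the softmax step applying the fusion mask.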
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
if (warp_idx == 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tOrV); i++) {
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(acc_qk_element_mk); j++) {
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
}
warpgroup_arrive();
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(tOrV); j++) {
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
}
}
warpgroup_commit_batch();
// Release the consumed pipeline stage
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
if (warp_idx == 0) {
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_pv);
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
return make_tuple(acc_pv, lse);
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,560 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
using cutlass::fmha::kernel::Tag;
using cutlass::fmha::kernel::find_option_t;
template<
class Element_,
class ElementAccumulatorQK_,
class ElementAccumulatorPV_,
class TileShape_, // SeqQ, SeqKV, Head
class LayoutQ_, class LayoutK_, class LayoutV_, // SeqX, Head, (Batches)
class Fusion,
class... Options
>
struct FmhaMainloopTmaWarpSpecialized {
using Element = Element_;
using ElementAccumulatorQK = ElementAccumulatorQK_;
using ElementAccumulatorPV = ElementAccumulatorPV_;
using TileShape = TileShape_;
using LayoutQ = LayoutQ_;
using LayoutK = LayoutK_;
using LayoutV = LayoutV_;
// Options
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
static constexpr bool kIsMainloopLocked = find_option_t<Tag::kIsMainloopLocked, false_type, Options...>::value;
static constexpr int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = find_option_t<Tag::kNumMmaWarpGroups, Int<2>, Options...>::value;
static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<5>, Options...>::value;
static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<NumMmaWarpGroups>, Options...>::value;
static const int kOuterLoads = 1;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<_1, _1, _1>;
static_assert(StagesQ::value >= NumMmaWarpGroups);
static_assert(Stages::value >= 2);
// 16B alignment lets us use TMA
static constexpr int Alignment = 16 / sizeof(Element);
using TileShapeQK = Shape<
decltype(tuple_element_t<0, TileShape>{} / Int<NumMmaWarpGroups>{}),
tuple_element_t<1, TileShape>,
tuple_element_t<2, TileShape>>;
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, LayoutQ, Alignment,
Element, LayoutK, Alignment,
ElementAccumulatorQK,
TileShapeQK, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since A (the softmaxed QK accumulator) is fed from registers and never loaded from smem
Element, LayoutK, Alignment,
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
ElementAccumulatorPV,
TileShapePV, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element);
using TileShapeOut = TileShapePV;
using TiledMmaOut = TiledMmaPV;
using ElementOut = ElementAccumulatorPV;
struct SharedStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
};
struct Arguments {
const Element* ptr_Q;
LayoutQ dQ;
const Element* ptr_K;
LayoutK dK;
const Element* ptr_V;
LayoutV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
float scale_softmax;
float scale_softmax_log2;
float rp_dropout;
};
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kQ,
MainloopPipelineQ,
Element,
SmemLayoutQ,
TMA_Q
>;
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kK,
MainloopPipeline,
Element,
SmemLayoutK,
TMA_K
>;
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kV,
MainloopPipeline,
Element,
SmemLayoutV,
TMA_V
>;
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
return true
&& (get<4>(problem_size) <= get<2>(TileShape{}))
&& ((get<4>(problem_size) % Alignment) == 0)
&& ((get<2>(problem_size) % Alignment) == 0)
;
}
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
typename CollectiveMmaQK::Arguments {
args.ptr_Q, args.dQ,
args.ptr_K, args.dK,
}, /*workspace=*/ nullptr);
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
typename CollectiveMmaPV::Arguments {
args.ptr_K, args.dK, // never used, dummy
args.ptr_V, select<1,0,2>(args.dV),
}, /*workspace=*/ nullptr);
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
1.0f
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
template<bool kLoadQ, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_kv_maybe_q(
int block_rank_in_cluster,
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline, PipelineState& smem_pipe_write,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
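// Producer warp path: issue TMA loads for K and V (and Q when kLoadQ), alternating K and V tiles
// drawn from a shared iterator; k_tile_count covers both operands, hence 2 * fusion_tile_count.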
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
int lane_predicate = cute::elect_one_sync();
uint16_t mcast_mask_b = 0;
if (lane_predicate == 1) {
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
}
}
}
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
[[maybe_unused]] int q_tile_count = NumMmaWarpGroups;
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
int k_tile_count = 2 * fusion_tile_count;
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
if constexpr (kLoadQ) {
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
if constexpr (kLoadQ) {
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
if constexpr (! kLoadQ) {
if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
}
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
if constexpr (kLoadQ) {
while (q_tile_count > 0) {
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
}
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
}
}
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_maybe_q(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
int lane_predicate = cute::elect_one_sync();
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
CUTLASS_PRAGMA_UNROLL
for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) {
int count = 1;
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count);
if (q_tile_count == 0 && do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
}
}
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
CUTLASS_DEVICE void
reduce(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
SharedStorage& storage)
{ /* no-op */ }
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
CUTLASS_DEVICE auto
compute(
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline, PipelineState& smem_pipe_read,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q,
MainloopPipelineReducer&, PipelineStateReducer&,
SharedStorage& storage,
MathWgOrderBarrier& math_wg_order_barrier)
{
int thread_idx = int(threadIdx.x);
PipelineState smem_pipe_release = smem_pipe_read;
PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
TiledMmaQK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
// Mainloop setup QK
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_N,MMA_K,PIPE)
// Prepare: MMA PV
TiledMmaPV tiled_mma_pv;
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
// Mainloop setup PV
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_N,MMA_K,PIPE)
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
pipeline_q.consumer_wait(smem_pipe_read_q);
// mapping into QK accumulator
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
Tensor tPcP = thr_mma_qk.partition_C(cP);
int m_block = get<0>(wg_coord);
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
// Allocate PV acc
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulatorQK, Fusion, decltype(params)> softmax{params};
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
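// The first KV tile is peeled: the PV accumulator is zero-initialized via gemm_zero_acc, and
// math_wg_order_barrier serializes the math warp groups around the QK MMA issue.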
if (true)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
math_wg_order_barrier.wait();
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
math_wg_order_barrier.arrive();
++smem_pipe_read;
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
// Advance consumer pipeline
++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
softmax.template step<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
pipeline.consumer_wait(smem_pipe_read, tok);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
//if constexpr (kIsPersistent)
// if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
pipeline.consumer_wait(smem_pipe_read, tok);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
if (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q);
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_pv);
if (kIsPersistent) pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
return make_tuple(acc_pv, lse);
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,245 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/kernel_hardware_info.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
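// Issue a GMMA over all K-blocks of the given operand fragments; after the first k-block the atom
// accumulates into tC. gemm_zero_acc additionally clears the accumulator on the first k-block.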
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
constexpr int rA = decltype(rank(tA))::value;
constexpr int rB = decltype(rank(tB))::value;
constexpr int rC = decltype(rank(tC))::value;
if constexpr (rA == 2 && rB == 2 && rC == 1) {
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<1>(tA); k_block++) {
cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
} else {
static_assert(rA == 3 && rB == 3 && rC == 3);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
}
}
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
atom.accumulate_ = GMMA::ScaleOut::Zero;
gemm_reset_zero_acc(atom, tA, tB, tC);
}
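// Tree reduction over a small static tensor: halve-and-combine while the size is even,
// then fall back to a serial fold.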
template<typename T, typename Fn>
CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) {
if constexpr (decltype(size(t) % _2{} == _0{})::value) {
auto partial = make_tensor<typename T::value_type>(size(t) / _2{});
CUTE_UNROLL
for (int i = 0; i < size(partial); i++) {
partial(i) = fn(t(i), t(i + size(partial)));
}
return reduce(partial, fn);
} else {
auto result = t(_0{});
CUTE_UNROLL
for (int i = 1; i < size(t); i++) {
result = fn(result, t(i));
}
return result;
}
}
struct fmha_max {
CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); }
};
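// layout_separate splits a value-mode layout into the sub-modes whose reference strides fall below /
// at-or-above a threshold; layout_acc_mn and layout_op_mk_v use it to view register fragments as
// (M, N) or (M, K) tensors.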
template<typename Threshold, typename Source, typename Reference>
inline auto __device__ constexpr layout_separate(Threshold const& thr,
Source const& src, Reference const& ref) {
auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
if constexpr(decltype(r < thr)::value) {
return s;
} else {
return make_layout(_1{}, _0{});
}
}));
auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
if constexpr(decltype(r >= thr)::value) {
return s;
} else {
return make_layout(_1{}, _0{});
}
}));
return make_layout(lt, ge);
}
template<typename TiledMma, typename Acc>
inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) {
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{}));
auto V_M = get<0>(separated);
auto V_N = get<1>(separated);
return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc)));
}
template<typename TiledMma, typename Acc>
inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) {
return layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{}));
}
template<typename TiledMma, typename Acc>
inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& acc) {
return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout()));
}
template<typename TiledMma>
inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) {
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
make_layout(shape<0>(typename TiledMma::LayoutC_TV{})),
stride<0>(typename TiledMma::LayoutC_TV{}));
return get<1>(separated);
}
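// Rebuild a GMMA MMA atom (or TiledMMA) with its RS variant so that operand A is read from registers
// rather than shared memory.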
template<template<cute::GMMA::Major, cute::GMMA::Major, cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<tA, tB, sA, sB>> const& tiled_mma) {
using Atom = cute::MMA_Atom<Primitive<tA, tB, sA, sB>>;
using ElementA = typename Atom::ValTypeA;
using ElementB = typename Atom::ValTypeB;
using ElementC = typename Atom::ValTypeC;
using Shape_MNK = typename Atom::Shape_MNK;
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
return cute::MMA_Atom<RS>{};
}
template<template<cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<sA, sB>> const& tiled_mma) {
using Atom = cute::MMA_Atom<Primitive<sA, sB>>;
using ElementA = typename Atom::ValTypeA;
using ElementB = typename Atom::ValTypeB;
using ElementC = typename Atom::ValTypeC;
using Shape_MNK = typename Atom::Shape_MNK;
constexpr auto tA = cute::GMMA::Major::K;
constexpr auto tB = cute::GMMA::Major::K;
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
return cute::MMA_Atom<RS>{};
}
template<class Atom, class... Args>
CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA<Atom, Args...> const& tiled_mma) {
return cute::TiledMMA<decltype(convert_to_gmma_rs(Atom{})), Args...>{};
}
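// View a C-accumulator fragment layout as an A-operand fragment layout: the value mode is regrouped
// to the A value shape and the remainder is folded into an extra K mode.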
template<typename CLayout, typename AValueShape>
CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) {
return make_layout(
make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))),
make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c))));
}
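// Swap the stage mode of a staged smem layout for a layout with the requested number of stages
// (e.g. Q keeps its own stage count).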
template<class Layout, class Stages = _1>
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, make_layout(stages)));
}
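// Copy an accumulator fragment into an A-operand fragment of type Element; for 8-bit elements,
// lanes within a quad exchange 16-bit halves via shuffles and __byte_perm so the data lands in the
// register layout the RS GMMA expects.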
template<class Element, class Accumulator, class OperandLayout_TV>
CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) {
Tensor operand = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv)));
Tensor operand_as_acc = make_tensor(operand.data(), acc.layout());
cute::copy(acc, operand_as_acc);
if constexpr (sizeof(Element) == 1) {
// 00 11 22 33 00 11 22 33 acc layout
// 00 00 11 11 22 22 33 33 operand layout
// BB AA AA BB AA BB BB AA conflict-free exchange pattern
// 16-bit exchange; so process two at a time potentially
int tid = threadIdx.x % 4;
auto values_u32 = recast<uint32_t>(operand);
CUTE_UNROLL
for (int n = 0; n < size<1>(values_u32); n++) {
CUTE_UNROLL
for (int k = 0; k < size<2>(values_u32); k++) {
CUTE_UNROLL
for (int ii = 0; ii < 8; ii += 4) {
uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k);
uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k);
// step A:
// t 1 v 0 -> t 0 v 1
// t 2 v 0 -> t 1 v 0
// t 0 v 1 -> t 2 v 0
// t 3 v 1 -> t 3 v 1
int v_to_send = tid == 1 || tid == 2 ? 0 : 1;
int v_to_recv = v_to_send;
int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF;
uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4);
// step B:
// t 0 v 0 -> t 0 v 0
// t 3 v 0 -> t 1 v 1
// t 1 v 1 -> t 2 v 1
// t 2 v 1 -> t 3 v 0
v_to_send = 1 - v_to_send;
v_to_recv = 1 - v_to_recv;
t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF;
uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4);
values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410);
values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632);
}
}
}
}
return operand;
}
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,156 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../collective/fmha_common.hpp"
namespace cutlass::fmha::collective {
template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaFwdEpilogue {
static constexpr int Alignment = 16 / sizeof(Element);
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
cutlass::epilogue::TmaWarpSpecialized,
DefaultOperation
>::CollectiveOp;
struct Arguments {
Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> dO;
ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
};
struct Params {
ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
typename CollectiveEpilogueTMA::Params epilogue_TMA;
};
using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage;
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1,
make_shape(get<0>(problem_size), get<1>(problem_size)));
typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO};
return Params{
args.ptr_LSE, args.dLSE,
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace)
};
}
template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
CUTLASS_DEVICE void operator()(
TileShape const& tile_shape, BlkCoord const& blk_coord,
ResultTuple const& result, TiledMma const& tiled_mma,
ProblemShape const& problem_size, Params const& params,
LoadPipeline epi_load_pipeline,
TensorStorage& epi_tensor_storage)
{
using X = Underscore;
auto acc = get<0>(result);
auto lse = get<1>(result);
auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
int seqlen_q = get<2>(problem_size);
int num_batch = get<0>(problem_size);
int num_heads = get<1>(problem_size);
// Epilogue for lse
Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE),
make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)),
make_stride(_1{}, _0{}, get<1>(params.dLSE)));
Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{});
Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord));
Tensor tOgLSE = thr_mma.partition_C(gLSE);
Tensor cO = make_identity_tensor(take<0,2>(tile_shape));
Tensor tOcO = thr_mma.partition_C(cO);
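// Only threads that own column 0 of the accumulator write LSE, one value per in-bounds query row.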
if (get<1>(tOcO(_0{})) == 0) {
auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout()));
auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout()));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tOgLSE_mn); i++) {
if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) {
tOgLSE_mn(i, _0{}) = lse(i);
}
}
}
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _,
make_shape(get<0>(problem_size), get<1>(problem_size)));
CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage);
using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_tma.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage
);
epilogue_tma.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,157 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "../collective/fmha_epilogue.hpp"
namespace cutlass::fmha::collective {
template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaBwdEpilogueKV {
static constexpr int Alignment = 16 / sizeof(Element);
struct Arguments {
Element* ptr_K;
cute::tuple<int, int, int, cute::_1> dK;
Element* ptr_V;
cute::tuple<int, int, int, _1> dV;
};
//using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90Compute<cutlass::first, Element, ElementAccumulator, RoundStyle>,
cutlass::epilogue::fusion::Sm90AccFetch
>;
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
cutlass::epilogue::TmaWarpSpecialized,
DefaultOperation
>::CollectiveOp;
struct Params {
typename CollectiveEpilogueTMA::Params epilogue_K;
typename CollectiveEpilogueTMA::Params epilogue_V;
};
using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2];
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
auto dK = make_stride(get<2>(args.dK), get<3>(args.dK),
make_stride(get<0>(args.dK), get<1>(args.dK)));
auto dV = make_stride(get<2>(args.dV), get<3>(args.dV),
make_stride(get<0>(args.dV), get<1>(args.dV)));
auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1,
make_shape(get<0>(problem_size), get<1>(problem_size)));
typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK};
typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV};
return Params{
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr),
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr)
};
}
template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
CUTLASS_DEVICE void operator()(
TileShape const& tile_shape, BlkCoord const& blk_coord,
ResultTuple const& result, TiledMma const& tiled_mma,
ProblemShape const& problem_size, Params const& params,
LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage)
{
auto acc_k = get<0>(result);
auto acc_v = get<1>(result);
auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _,
make_shape(get<0>(problem_size), get<1>(problem_size)));
using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]};
CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]};
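// The two accumulators are stored through two TMA epilogues that share the same load/store pipeline
// states; both store tails are drained after the second store is issued.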
{
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_k.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage[0]
);
}
{
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_v.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage[1]
);
epilogue_k.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
epilogue_v.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
}
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,283 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
struct DefaultFusion {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return ceil_div(get<3>(problem_size), get<1>(tile_shape));
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return 0;
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
return;
}
};
struct ResidualFusion : DefaultFusion {
using Base = DefaultFusion;
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return 1;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
// This is useful if seqlen_k % kBlockN != 0, since it masks
// the remaining elements out of the softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM != 0 do not suffer from similar
// issues, as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<1>(pos) >= get<3>(problem_size)) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct CausalFusion : DefaultFusion {
using Base = DefaultFusion;
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
// See the note in before_softmax below on the different ways to think about causal attention
// For variant (2) there, we'd also add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is what we demonstrate here
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<0>(pos) < get<1>(pos)) {
acc_qk(i) = -INFINITY;
}
}
}
};
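// Illustrative sketch (hypothetical, for exposition only): variant (2) from the note in
// CausalFusion::before_softmax above, where Q is aligned to the end of K, as is typical
// for inference. problem_size is (B, H, seqlen_q, seqlen_k, head_dim), so the offset
// seqlen_k - seqlen_q shifts the causal diagonal to the bottom-right corner. A complete
// implementation would also add the same offset into max_blocks_q in get_trip_count.
struct CausalFusionBottomRight : CausalFusion {
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
// Aligns the last query row with the last key column
int offset_q = get<3>(problem_size) - get<2>(problem_size);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<0>(pos) + offset_q < get<1>(pos)) {
acc_qk(i) = -INFINITY;
}
}
}
};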
template<class Base>
struct FusionBwdAdapter {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size));
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
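// The backward kernel produces the QK accumulator in transposed (K, Q) order, so swap
// the two index modes (both the base coordinate and the scaled-basis strides) back to
// (Q, K) before delegating to the forward fusion's masking predicate.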
auto index_base = index_qk(_0{});
auto index_shape = shape(index_qk);
auto index_stride = transform_leaf(stride(index_qk), [](auto elem) {
if constexpr (is_scaled_basis<decltype(elem)>::value) {
if constexpr(decltype(elem.mode() == _0{})::value) {
return ScaledBasis<decltype(elem.value()), 1>(elem.value());
} else {
return ScaledBasis<decltype(elem.value()), 0>(elem.value());
}
} else {
return elem;
}
});
auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride));
Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size);
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
bool is_contributing(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return true;
}
};
template<>
struct FusionBwdAdapter<CausalFusion> {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get<2>(problem_size) / get<0>(TileShape{});
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<1>(pos) < get<0>(pos)) {
acc_qk(i) = -INFINITY;
}
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
bool is_contributing(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape);
int min_k = get<1>(blk_coord) * get<1>(tile_shape);
return min_k <= max_q;
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,278 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief A universal device layer for CUTLASS 3.x-style kernels.
*/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template <class Kernel_>
class Universal {
public:
using Kernel = Kernel_;
static int const kThreadCount = Kernel::MaxThreadsPerBlock;
/// Argument structure: User API
using Arguments = typename Kernel::Arguments;
/// Argument structure: Kernel API
using Params = typename Kernel::Params;
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (Kernel::can_implement(args)) {
return Status::kSuccess;
}
else {
return Status::kInvalid;
}
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
return workspace_bytes;
}
/// Computes the grid shape
static dim3
get_grid_shape(Params const& params) {
return Kernel::get_grid_shape(params);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
// Initialize the Params structure
params_ = Kernel::to_underlying_arguments(args, workspace);
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
params_ = Kernel::to_underlying_arguments(args, workspace);
return Status::kSuccess;
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = get_grid_shape(params);
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
cutlass::arch::synclog_setup();
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result && Status::kSuccess == launch_result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
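// Example usage (an illustrative sketch, not a prescribed API; `FmhaKernel` stands for any
// kernel type exposing the Arguments/Params interface assumed above, and CUTLASS_CHECK is
// the helper macro used throughout the CUTLASS examples):
//
//   using Operation = cutlass::device::Universal<FmhaKernel>;
//   typename Operation::Arguments args = make_arguments(...); // hypothetical helper
//
//   Operation op;
//   if (Operation::can_implement(args) != Status::kSuccess) { /* fall back */ }
//   cutlass::device_memory::allocation<uint8_t> workspace(Operation::get_workspace_size(args));
//   CUTLASS_CHECK(op.initialize(args, workspace.get(), stream));
//   CUTLASS_CHECK(op.run(stream));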
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::device
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,299 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/*!
\file
\brief A universal device layer for CUTLASS 3.x-style kernels.
*/
// common
#include "cutlass/cutlass.h"
#include "../device/device_universal.hpp"
#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/fmha_epilogue_bwd.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<class Element, class ElementAccumulator, class TileShape, class Fusion, class... Options>
class FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
cute::tuple<int, int, int, int, int> problem_size;
const Element* ptr_Q;
cute::tuple<int, int, int, cute::_1> stride_Q;
const Element* ptr_K;
cute::tuple<int, int, int, cute::_1> stride_K;
const Element* ptr_V;
cute::tuple<int, int, int, cute::_1> stride_V;
const Element* ptr_O;
cute::tuple<int, int, int, cute::_1> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<int, int, _1> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, int, int, cute::_1> stride_dO;
Element* ptr_dQ;
cute::tuple<int, int, int, cute::_1> stride_dQ;
Element* ptr_dK;
cute::tuple<int, int, int, cute::_1> stride_dK;
Element* ptr_dV;
cute::tuple<int, int, int, cute::_1> stride_dV;
cutlass::KernelHardwareInfo hw_info;
};
using OperationSumOdO = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>>;
using OperationConvert = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>>;
using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized<
Element, ElementAccumulator, TileShape,
cutlass::fmha::collective::FusionBwdAdapter<Fusion>, Options...>;
using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV<Element, ElementAccumulator, typename Mainloop::TileShapePV>;
using Operation = cutlass::device::Universal<
cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<
Mainloop,
Epilogue,
cutlass::fmha::kernel::TileSchedulerBwdAdapter<cutlass::fmha::kernel::IndividualTileScheduler>, Options...>>;
struct Params {
OperationSumOdO op_sum_OdO;
Operation op;
OperationConvert op_convert;
ElementAccumulator* dQ_acc;
size_t dQ_acc_size;
};
private:
Params params_;
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) {
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_sum_OdO = make_stride(H*Q, Q, _1{});
return typename OperationSumOdO::Arguments {
args.problem_size,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
dest, stride_sum_OdO
};
}
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_src_dQ = make_stride(B == 1 ? 0 : (H*Q*D), Q*D, D, _1{});
return typename OperationConvert::Arguments {
args.problem_size,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV
};
}
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<int, int, _1> const& stride_sum_OdO = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, int, int, _1> const& stride_dQ = {}
) {
return typename Operation::Arguments{
args.problem_size,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
args.ptr_LSE, args.stride_LSE,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
args.hw_info
};
}
public:
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
Status status = Status::kSuccess;
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = OperationConvert::can_implement(to_convert_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = Operation::can_implement(to_bwd_arguments(args));
if (status != Status::kSuccess) {
return status;
}
return status;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
// FP32 versions of outputs that are accumulated across tiles (currently dQ only)
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
return workspace_bytes;
}
/// Initializes state from arguments.
Status
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
params_.dQ_acc = dQ_acc;
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO);
auto args_convert = to_convert_arguments(args, dQ_acc);
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
params_.op_convert.initialize(args_convert, nullptr, stream);
auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ);
params_.op.initialize(args_bwd, nullptr, stream);
return Status::kSuccess;
}
/// Initializes state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
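// Partition the caller-provided workspace: the first B*H*Q accumulator elements hold the
// row-wise sum of O*dO, and the remainder holds the FP32 dQ accumulator.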
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
return initialize_split(args, dQ_acc, sum_OdO, stream);
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
Status result = Status::kSuccess;
result = params.op_sum_OdO.run(stream);
if (result != Status::kSuccess) {
return result;
}
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
if (cuda_result != cudaSuccess) {
return Status::kErrorInternal;
}
result = params.op.run(stream);
if (result != Status::kSuccess) {
return result;
}
result = params.op_convert.run(stream);
if (result != Status::kSuccess) {
return result;
}
return Status::kSuccess;
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,158 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "../collective/fmha_collective_tma.hpp"
#include "../collective/fmha_collective_tma_warpspecialized.hpp"
#include "../collective/fmha_epilogue.hpp"
#include "../kernel/fmha_kernel_tma.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::kernel {
template<
class Element_,
class ElementAccumulatorQK_,
class ElementAccumulatorPV_,
class TileShape_, // BlockQO, BlockKV, BlockHead
class LayoutQ_,
class LayoutK_,
class LayoutV_,
class Fusion,
class DispatchPolicy,
class... Options
>
struct FmhaBuilder;
template<
class Element,
class ElementAccumulator,
class TileShape, // BlockQO, BlockKV, BlockHead
class Fusion,
class... Options
>
struct FmhaBuilder<
Element,
ElementAccumulator,
ElementAccumulator,
TileShape,
cute::tuple<int, _1, cute::tuple<int, int>>,
cute::tuple<int, _1, cute::tuple<int, int>>,
cute::tuple<int, _1, cute::tuple<int, int>>,
Fusion,
cutlass::gemm::KernelTma,
Options...
> {
using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma<Element, ElementAccumulator, TileShape, Fusion, Options...>;
using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>;
using Kernel = cutlass::fmha::kernel::FmhaKernelTma<CollectiveMainloop, CollectiveEpilogue, Options...>;
};
template<
class Element,
class ElementAccumulatorQK,
class ElementAccumulatorPV,
class TileShape, // BlockQO, BlockKV, BlockHead
class LayoutQ,
class LayoutK,
class LayoutV,
class Fusion,
class... Options
>
struct FmhaBuilder<
Element,
ElementAccumulatorQK,
ElementAccumulatorPV,
TileShape,
LayoutQ,
LayoutK,
LayoutV,
Fusion,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Options...
> {
using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized<
Element, ElementAccumulatorQK, ElementAccumulatorPV,
TileShape, LayoutQ, LayoutK, LayoutV,
Fusion, Options...>;
using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>;
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<CollectiveMainloop, CollectiveEpilogue, TileScheduler, Options...>;
};
template<
class Element,
class ElementAccumulatorQK,
class ElementAccumulatorPV,
class TileShape, // BlockQO, BlockKV, BlockHead
class LayoutQ,
class LayoutK,
class LayoutV,
class Fusion,
class... Options
>
struct FmhaBuilder<
Element,
ElementAccumulatorQK,
ElementAccumulatorPV,
TileShape,
LayoutQ,
LayoutK,
LayoutV,
Fusion,
cutlass::gemm::KernelTmaWarpSpecializedPingpong,
Options...
> {
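// The ping-pong schedule is realized by reusing the cooperative builder with the
// persistent tile scheduler and a separate Q-load warp enabled via the option list.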
using Kernel = typename FmhaBuilder<
Element, ElementAccumulatorQK, ElementAccumulatorPV,
TileShape,
LayoutQ, LayoutK, LayoutV,
Fusion,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Options...,
Option<Tag::kIsPersistent, true_type>,
Option<Tag::kLoadsQSeparately, true_type>
>::Kernel;
};
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,143 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<class Element, class ElementAccumulator>
struct FmhaKernelBwdConvert {
struct Arguments {
tuple<int, int, int, int, int> problem_size;
const ElementAccumulator* ptr_src_dQ;
tuple<int, int, int, _1> stride_src_dQ;
const ElementAccumulator* ptr_src_dK;
tuple<int, int, int, _1> stride_src_dK;
const ElementAccumulator* ptr_src_dV;
tuple<int, int, int, _1> stride_src_dV;
Element* ptr_dest_dQ;
tuple<int, int, int, _1> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, int, int, _1> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, int, int, _1> stride_dest_dV;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm90;
static const int kBlockSeq = 8;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kNumThreadsD = 16;
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 4;
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<4>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
template<class StrideSrc, class StrideDest>
CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y;
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
if (idx_s >= count) continue;
auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
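// Load kElementsPerLoad accumulator values with a single vectorized access, convert them
// element-wise to the narrower output type, and write them back with one vectorized store.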
ElementAccumulator value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAccumulator> * kElementsPerLoad>;
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
value_dest[v] = value_src[v];
}
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
}
}
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size));
}
}
};
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,134 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<class Element, class ElementAccumulator>
struct FmhaKernelBwdSumOdO {
struct Arguments {
cute::tuple<int, int, int, int, int> problem_size;
const Element* ptr_O;
cute::tuple<int, int, int, cute::_1> stride_O;
const Element* ptr_dO;
cute::tuple<int, int, int, cute::_1> stride_dO;
ElementAccumulator* ptr_sum_OdO;
cute::tuple<int, int, _1> stride_sum_OdO;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm90;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kBlockQ = 16;
static const int kNumThreadsD = 8;
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 2;
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<4>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO);
CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= get<2>(params.problem_size)) continue;
ElementAccumulator acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO);
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
acc += value_O[v] * value_dO[v];
}
}
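// Butterfly reduction across the kNumThreadsD lanes that cooperated on this row; after
// log2(kNumThreadsD) xor-shuffle steps every participating lane holds the full row sum.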
for (int i = 1; i < kNumThreadsD; i *= 2) {
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
}
if (threadIdx.x == 0) {
*ptr_sum_OdO_bhq = acc;
}
}
}
};
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,222 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/arch.h"
#include "../kernel/fmha_tile_scheduler.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::kernel {
template<
class CollectiveMainloop,
class CollectiveEpilogue,
class... Options
>
struct FmhaKernelTma {
// Options
static constexpr int kBlocksPerSM = find_option_t<Tag::kBlocksPerSM, Int<2>, Options...>::value;
using Element = typename CollectiveMainloop::Element;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using TileScheduler = IndividualTileScheduler;
using StagesQ = typename CollectiveMainloop::StagesQ;
using Stages = typename CollectiveMainloop::Stages;
using TileShape = typename CollectiveMainloop::TileShape;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ;
using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK;
struct SharedStorage {
union {
typename CollectiveMainloop::SharedStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
using PipelineStorage = typename MainloopPipeline::SharedStorage;
using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage;
alignas(16) PipelineStorage pipeline_storage;
alignas(16) PipelineStorageQ pipeline_storage_q;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) EpiLoadPipelineStorage epi_load;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
using ProblemShape = cute::tuple<int, int, int, int, int>;
struct Arguments {
ProblemShape problem_size;
typename CollectiveMainloop::Arguments mainloop;
typename CollectiveEpilogue::Arguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_size;
typename CollectiveMainloop::Params mainloop;
typename CollectiveEpilogue::Params epilogue;
typename TileScheduler::Params tile_scheduler;
};
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
static const int MinBlocksPerMultiprocessor = kBlocksPerSM;
static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
using ArchTag = cutlass::arch::Sm90;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static bool can_implement(Arguments const& args) {
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
}
static dim3 get_grid_shape(Params const& params) {
return TileScheduler::get_grid_shape(params.tile_scheduler);
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return Params{
args.problem_size,
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
};
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
TileScheduler tile_scheduler{params.tile_scheduler};
// Shared memory.
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
int thread_idx = int(threadIdx.x);
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
int warp_idx = cutlass::canonical_warp_idx_sync();
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
int lane_predicate = cute::elect_one_sync();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
}
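// The same threads act as both TMA producer and MMA consumer in this kernel, so both
// mainloop pipelines are configured with the ProducerConsumer thread category and a full
// warp group of consumers.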
PipelineParamsQ pipeline_params_q;
pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); // Q
pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer;
pipeline_params_q.is_leader = warp_group_thread_idx == 0;
pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); // KV
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup;
MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{});
MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{});
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer;
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e. the MMA
// smem_pipe_write is used by the producer of SMEM data - i.e. the TMA
PipelineState smem_pipe_read;
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineStateQ smem_pipe_read_q;
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer blocks in the cluster,
// and to finish smem init
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
}
else {
__syncthreads();
}
auto blk_coord = tile_scheduler.get_block_coord();
CollectiveMainloop collective_mainloop;
auto result = collective_mainloop.compute(
block_rank_in_cluster,
blk_coord, params.mainloop, params.problem_size,
pipeline, smem_pipe_read, smem_pipe_write,
pipeline_q, smem_pipe_read_q, smem_pipe_write_q,
storage.mainloop
);
CollectiveEpilogue epilogue;
epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord,
result, typename CollectiveMainloop::TiledMmaPV{},
params.problem_size, params.epilogue,
epi_load_pipeline, storage.epilogue);
}
};
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,418 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/arch.h"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class CollectiveMainloop,
class CollectiveEpilogue,
class TileScheduler,
class... Options
>
struct FmhaKernelTmaWarpSpecialized {
// Options
static constexpr bool kIsEpilogueLocked = find_option_t<Tag::kIsEpilogueLocked, false_type, Options...>::value;
static constexpr bool kLoadsQSeparately = find_option_t<Tag::kLoadsQSeparately, false_type, Options...>::value;
static const int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups;
using TileShape = typename CollectiveMainloop::TileShape;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ;
using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline;
using MainloopPipelineReducer = cutlass::PipelineAsync<2>;
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<
StagesPerMathWarpGroup, NumMmaWarpGroups>;
struct TensorStorageStruct {
typename CollectiveMainloop::SharedStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
};
union TensorStorageUnion {
typename CollectiveMainloop::SharedStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
};
using TensorStorage = std::conditional_t<CollectiveMainloop::kIsPersistent, TensorStorageStruct, TensorStorageUnion>;
struct SharedStorage {
TensorStorage tensors;
using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage;
using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage;
using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage;
alignas(16) PipelineStorageInner pipeline_storage_inner;
alignas(16) PipelineStorageOuter pipeline_storage_outer;
alignas(16) PipelineStorageReducer pipeline_storage_reducer;
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) EpiLoadPipelineStorage epi_load;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
using ProblemShape = cute::tuple<int, int, int, int, int>;
struct Arguments {
ProblemShape problem_size;
typename CollectiveMainloop::Arguments mainloop;
typename CollectiveEpilogue::Arguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_size;
typename CollectiveMainloop::Params mainloop;
typename CollectiveEpilogue::Params epilogue;
typename TileScheduler::Params tile_scheduler;
};
using PipelineParamsInner = typename MainloopPipelineInner::Params;
using PipelineStateInner = typename cutlass::PipelineState<MainloopPipelineInner::Stages>;
using PipelineParamsOuter = typename MainloopPipelineOuter::Params;
using PipelineStateOuter = typename cutlass::PipelineState<MainloopPipelineOuter::Stages>;
using PipelineParamsReducer = typename MainloopPipelineReducer::Params;
using PipelineStateReducer = typename cutlass::PipelineState<MainloopPipelineReducer::Stages>;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup;
using ArchTag = cutlass::arch::Sm90;
static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8;
static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup;
static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8;
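// Worked example of the arithmetic above (assuming a mainloop with NumMmaWarpGroups == 2,
// i.e. 384 threads per block; actual values depend on the CollectiveMainloop):
//   TotalRegisterSupply     = (65536 / 384 / 1 / 8) * 8 * 384 / 128 = 504 registers per warp group
//   LoadRegisterRequirement = 40 - 16 = 24 for the producer warp group
//   MmaRegisterRequirement  = ((504 - 24) / 2 / 8) * 8 = 240 for each MMA warp group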
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static bool can_implement(Arguments const& args) {
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
}
static dim3 get_grid_shape(Params const& params) {
return TileScheduler::get_grid_shape(params.tile_scheduler);
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return Params{
args.problem_size,
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
};
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
enum class WarpGroupRole {
Producer = 0,
Consumer0 = 1,
Consumer1 = 2,
Consumer2 = 3,
Consumer3 = 4,
};
enum class ProducerWarpRole {
LoadKV = 1,
Reducer = 0,
MaybeLoadQ = 2, // if kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it)
MainloopEpilogue = 3,
};
static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV;
TileScheduler tile_scheduler{params.tile_scheduler};
// Shared memory.
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
int lane_idx = cutlass::canonical_lane_idx();
int warp_idx = cutlass::canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup;
int warp_group_idx = cutlass::canonical_warp_group_idx();
auto warp_group_role = WarpGroupRole(warp_group_idx);
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0;
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
}
PipelineParamsOuter pipeline_params_outer;
pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes;
pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ);
pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup;
PipelineParamsInner pipeline_params_inner;
pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes;
pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV);
pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
PipelineParamsReducer pipeline_params_reducer;
pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp;
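// All MMA warp-group threads arrive as producers on the reducer pipeline; a single warp
// (the Reducer producer-warp role assigned below) consumes it.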
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) {
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) {
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) {
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1 ||
warp_group_role == WarpGroupRole::Consumer2 ||
warp_group_role == WarpGroupRole::Consumer3
) {
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer;
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer;
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer;
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{});
MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{});
MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer);
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e., MMA
// smem_pipe_write is used by the producer of SMEM data - i.e., TMA
PipelineStateInner smem_pipe_read_inner;
PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state<MainloopPipelineInner>();
PipelineStateOuter smem_pipe_read_outer;
PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state<MainloopPipelineOuter>();
PipelineStateReducer smem_pipe_read_reducer;
PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state<MainloopPipelineReducer>();
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_math_wg_order_barrier.group_id = consumer_warp_group_idx;
params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group
MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier);
// Epilogue Load pipeline
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
if constexpr (kLoadsQSeparately) {
if ((warp_idx == 0) && lane_predicate) {
storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp);
}
cutlass::arch::fence_barrier_init();
}
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer blocks in the Cluster
// and to finish smem init
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
}
else {
__syncthreads();
}
CollectiveMainloop collective_mainloop;
if (warp_group_role == WarpGroupRole::Producer) {
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
if (producer_warp_role == ProducerWarpRole::LoadKV) {
bool do_barrier = kLoadsQSeparately;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
collective_mainloop.template load_kv_maybe_q<!kLoadsQSeparately>(
block_rank_in_cluster,
blk_coord, params.mainloop, params.problem_size,
pipeline_inner, smem_pipe_write_inner,
pipeline_outer, smem_pipe_write_outer,
storage.tensors.mainloop,
storage.load_warp_barrier, do_barrier
);
do_barrier = false;
}
}
else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) {
bool do_barrier = true;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
collective_mainloop.load_maybe_q(
blk_coord, params.mainloop, params.problem_size,
pipeline_outer, smem_pipe_write_outer,
storage.tensors.mainloop,
storage.load_warp_barrier, do_barrier
);
do_barrier = false;
}
} else if (producer_warp_role == ProducerWarpRole::Reducer) {
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
collective_mainloop.reduce(
blk_coord, params.mainloop, params.problem_size,
pipeline_reducer, smem_pipe_read_reducer,
storage.tensors.mainloop
);
}
}
}
else if (
warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1 ||
warp_group_role == WarpGroupRole::Consumer2 ||
warp_group_role == WarpGroupRole::Consumer3
) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto wg_coord = blk_coord;
constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads;
if (warp_group_role == WarpGroupRole::Consumer0) {
smem_pipe_read_outer.advance(0 * kOuterLoads);
}
else if (warp_group_role == WarpGroupRole::Consumer1) {
smem_pipe_read_outer.advance(1 * kOuterLoads);
}
else if (warp_group_role == WarpGroupRole::Consumer2) {
smem_pipe_read_outer.advance(2 * kOuterLoads);
}
else if (warp_group_role == WarpGroupRole::Consumer3) {
smem_pipe_read_outer.advance(3 * kOuterLoads);
}
constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 0 : 1;
auto& wg_block = get<wg_dim>(wg_coord);
if (warp_group_role == WarpGroupRole::Consumer0) {
wg_block = NumMmaWarpGroups * wg_block + 0;
}
else if (warp_group_role == WarpGroupRole::Consumer1) {
wg_block = NumMmaWarpGroups * wg_block + 1;
}
else if (warp_group_role == WarpGroupRole::Consumer2) {
wg_block = NumMmaWarpGroups * wg_block + 2;
}
else if (warp_group_role == WarpGroupRole::Consumer3) {
wg_block = NumMmaWarpGroups * wg_block + 3;
}
auto result = collective_mainloop.compute(
blk_coord, wg_coord,
params.mainloop, params.problem_size,
pipeline_inner, smem_pipe_read_inner,
pipeline_outer, smem_pipe_read_outer,
pipeline_reducer, smem_pipe_write_reducer,
storage.tensors.mainloop,
math_wg_order_barrier
);
if (warp_group_role == WarpGroupRole::Consumer0) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0));
}
if constexpr (NumMmaWarpGroups >= 2) {
if (warp_group_role == WarpGroupRole::Consumer1) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1));
}
}
if constexpr (NumMmaWarpGroups >= 3) {
if (warp_group_role == WarpGroupRole::Consumer2) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2));
}
}
if constexpr (NumMmaWarpGroups >= 4) {
if (warp_group_role == WarpGroupRole::Consumer3) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3));
}
}
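// Symmetric to the pre-compute advance above: each warp group steps its outer read state past
// the Q stages belonging to the other warp groups, so all groups stay aligned on the shared
// MainloopPipelineOuter across tile-scheduler iterations.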
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.wait(); }
CollectiveEpilogue epilogue;
epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord,
result, typename CollectiveMainloop::TiledMmaPV{},
params.problem_size, params.epilogue,
epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]);
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.arrive(); }
}
}
}
};
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,83 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
namespace cutlass::fmha::kernel {
template<auto kTag, typename Default, typename... Options>
struct find_option;
template<auto kTag, typename Default>
struct find_option<kTag, Default> {
using option_value = Default;
};
template<auto kTag, typename Default, typename Option, typename... Options>
struct find_option<kTag, Default, Option, Options...> :
std::conditional_t<
Option::tag == kTag,
Option,
find_option<kTag, Default, Options...>
>
{};
template<auto kTag, typename Default, typename... Options>
using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
enum class Tag {
kIsPersistent,
kNumMmaWarpGroups,
kLoadsQSeparately,
kIsMainloopLocked,
kIsEpilogueLocked,
kStagesQ,
kStagesKV,
kEpilogueKind,
kBlocksPerSM,
kClusterM,
kAccQK
};
template<auto kTag, class Value>
struct Option {
static constexpr auto tag = kTag;
using option_value = Value;
};
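// Illustrative usage (hypothetical kernel/option names, following the Tag enum above):
//   using Kernel = FmhaKernelTmaWarpSpecialized<
//       Mainloop, Epilogue, Scheduler,
//       Option<Tag::kIsEpilogueLocked, true_type>,
//       Option<Tag::kStagesKV, _4>>;
// Inside the kernel, find_option_t<Tag::kStagesKV, _2, Options...> then resolves to _4;
// when no matching Option is supplied, it falls back to the Default argument (_2 here).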
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,204 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct IndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
IndividualTileScheduler(Params const&) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size));
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
}
CUTLASS_DEVICE
IndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_h;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size);
return Params {
num_blocks,
{ num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
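// Decode the flat persistent block index with the divisors set up in to_underlying_arguments():
// the M-block index varies fastest, then the size<0> mode, then the size<1> mode of the problem
// shape, i.e. block_idx == (idx1 * size<0>(problem) + idx0) * num_m_blocks + m_block.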
int m_block, bidb, bidh;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
}
CUTLASS_DEVICE
PersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
template<typename Base>
struct TileSchedulerBwdAdapter {
using Params = typename Base::Params;
Base base_;
CUTLASS_DEVICE
TileSchedulerBwdAdapter(Params const& params) : base_(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape));
}
static dim3 get_grid_shape(Params const& params) {
return Base::get_grid_shape(params);
}
CUTLASS_DEVICE
bool is_valid() {
return base_.is_valid();
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return select<1,0,2>(base_.get_block_coord());
}
CUTLASS_DEVICE
TileSchedulerBwdAdapter& operator++() {
++base_;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,357 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
class TensorDQ, /* class TensorDK, class TensorDV, */
class Fusion
>
void __global__ fmha_bwd_reference_dQ_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
Fusion fusion
) {
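// Per (Q, K) pair this reference recomputes the attention probability and the dS term:
//   P  = exp(scale * (q . k) - LSE(q))
//   dS = P * scale * ((dO . v) - (dO . o))    // (dO . o) is the row-sum correction term
// and then accumulates dQ(q, :) = sum_k dS(q, k) * K(k, :).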
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_K] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
}
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
}
mDQ(idx_Q, idx_D, idx_L) = acc;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/* class TensorDQ, */ class TensorDK, /* class TensorDV, */
class Fusion
>
void __global__ fmha_bwd_reference_dK_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
Fusion fusion
) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
}
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
}
mDK(idx_K, idx_D, idx_L) = acc;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/* class TensorDQ, class TensorDK, */ class TensorDV,
class Fusion
>
void __global__ fmha_bwd_reference_dV_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
Fusion fusion
) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)));
}
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L);
}
mDV(idx_K, idx_D, idx_L) = acc;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/
class Fusion
>
void fmha_bwd_reference_dQ(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_bwd_reference_dQ_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDQ, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/
class Fusion
>
void fmha_bwd_reference_dK(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_bwd_reference_dK_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDK, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/
class Fusion
>
void fmha_bwd_reference_dV(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_bwd_reference_dV_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDV, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
class TensorDQ, class TensorDK, class TensorDV,
class Fusion
>
void fmha_bwd_reference(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, TensorDK mDK, TensorDV mDV,
Fusion fusion
) {
fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,156 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ,
class TensorK,
class TensorV,
class TensorO,
class TensorLSE,
class Fusion
>
void __global__ fmha_reference_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE,
Fusion fusion
) {
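// Row-by-row forward reference: the Q.K dot products pass through fusion.before_softmax(), are
// scaled by 1/sqrt(D), then a max-stabilized softmax over the K dimension yields
// O = softmax(S) V and LSE = log(sum_k exp(S_k - max)) + max for each query row.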
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));
auto id = make_identity_tensor(make_shape(1, 1));
for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) {
acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L);
}
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
mS[idx_K] = static_cast<Element>(frag(0) * softmax_scale);
}
__syncthreads();
ElementAccumulator maxS = -std::numeric_limits<ElementAccumulator>::infinity();
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
maxS = std::max<ElementAccumulator>(maxS, mS[idx_K]);
}
if (maxS == -std::numeric_limits<ElementAccumulator>::infinity()) maxS = 0;
__syncthreads();
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
mS[idx_K] = static_cast<Element>(exp(mS[idx_K] - maxS));
}
__syncthreads();
ElementAccumulator sum = 0;
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
sum += mS[idx_K];
}
Element scale = static_cast<Element>(1.0 / sum);
for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale;
}
mO(idx_Q, idx_D, idx_L) = static_cast<Element>(acc);
}
if (threadIdx.x == 0) {
mLSE(idx_Q, idx_L) = log(sum) + maxS;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ,
class TensorK,
class TensorV,
class TensorO,
class TensorLSE,
class Fusion
>
void fmha_reference(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE,
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mO), size<2>(mO), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_reference_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_reference_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,129 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cmath>
#include "cutlass/util/device_memory.h"
template<typename Element>
__global__ void reference_abs_diff_kernel(
Element* data, Element* data_ref, size_t count,
double* max_diff, double* sum_diff,
bool print_diff
) {
double thread_max_diff = 0;
double thread_sum_diff = 0;
__shared__ double block_max_diff;
__shared__ double block_sum_diff;
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
double diff = fabs(data[i] - data_ref[i]);
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
thread_max_diff = fmax(diff, thread_max_diff);
thread_sum_diff += diff;
}
for (int i = 0; i < blockDim.x; i++) {
if (i == threadIdx.x) {
if (i == 0) {
block_max_diff = thread_max_diff;
block_sum_diff = thread_sum_diff;
} else {
block_max_diff = fmax(block_max_diff, thread_max_diff);
block_sum_diff += thread_sum_diff;
}
}
__syncthreads();
}
if (threadIdx.x == 0) {
atomicAdd(sum_diff, block_sum_diff);
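// CUDA has no native atomicMax for double, so emulate it: reinterpret the bits as a 64-bit
// integer and CAS-loop until the stored value is at least block_max_diff.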
for (;;) {
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
double prev_diff = reinterpret_cast<double const&>(prev);
double new_max_diff = fmax(block_max_diff, prev_diff);
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
if (found == prev) break;
}
}
}
template<typename Element>
void reference_abs_diff(
cutlass::DeviceAllocation<Element> const& data,
cutlass::DeviceAllocation<Element> const& data_ref,
double& max_diff, double& mean_diff
) {
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
cutlass::DeviceAllocation<double> result;
result.reset(2);
assert(data.size() == data_ref.size());
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
if (err != cudaSuccess) {
std::cerr << "Memset failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
dim3 block(256, 1, 1);
dim3 grid(1024, 1, 1);
reference_abs_diff_kernel<<<block, grid>>>(
data.get(), data_ref.get(), data.size(),
result.get(), result.get() + 1, kPrintDiff);
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
std::cerr << "Difference kernel failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
double result_host[2];
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
if (err != cudaSuccess) {
std::cerr << "Copy failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
max_diff = result_host[0];
mean_diff = result_host[1] / static_cast<double>(data.size());
}

View File

@ -163,6 +163,7 @@ foreach(EXAMPLE
82_blackwell_distributed_gemm
83_blackwell_sparse_gemm
84_blackwell_narrow_precision_sparse_gemm
88_hopper_fmha
)
add_subdirectory(${EXAMPLE})

View File

@ -55,3 +55,7 @@ cutlass_example_add_executable(
tiled_copy.cu
)
cutlass_example_add_executable(
cute_tutorial_tiled_copy_if
tiled_copy_if.cu
)

View File

@ -0,0 +1,297 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
// This example extends `tiled_copy` using predicate tensors to guard memory accesses performed
// by `cute::copy_if()`. This enables tensors to have shapes that are not integer multiples of
// block sizes.
//
// This is accomplished by instantiating a tensor of coordinates which correspond to tensor elements
// to be accessed and then computing a predicate tensor which masks accesses. The example demonstrates
// how constructing of an identity tensor containing coordinates and a predicate tensor containing
// mask bits can be implemented using the same CuTe operations used to tile the tensors in
// Global Memory.
//
// This example implements two variants:
// - copy_if_kernel() uses `cute::local_partition()` to construct each thread's slice
// - copy_if_kernel_vectorized() uses `make_tiled_copy()` to implement vectorized memory accesses.
//
// The tensor shapes and strides must be divisible by the shape of the vector access.
//
/// Simple copy kernel.
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class BlockShape, class ThreadLayout>
__global__ void copy_if_kernel(TensorS S, TensorD D, BlockShape block_shape, ThreadLayout)
{
using namespace cute;
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
auto shape_S = shape(S);
Tensor C = make_identity_tensor(shape_S);
// Construct a predicate tensor which compares the coordinates with the original shape
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
// Tile the input tensor into blocks
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
// Construct a partitioning of the tile among threads with the given thread arrangement.
// Concept: Tensor ThrLayout ThrIndex
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
Tensor thr_tile_P = local_partition(tile_P, ThreadLayout{}, threadIdx.x);
// Copy from GMEM to GMEM using `thr_tile_P` to guard accesses.
copy_if(thr_tile_P, thr_tile_S, thr_tile_D);
}
/// Vectorized copy kernel.
///
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class BlockShape, class Tiled_Copy>
__global__ void copy_if_kernel_vectorized(TensorS S, TensorD D, BlockShape block_shape, Tiled_Copy tiled_copy)
{
using namespace cute;
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
auto shape_S = shape(S);
Tensor C = make_identity_tensor(shape_S);
// Construct a predicate tensor which compares the coordinates with the original shape
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
// Tile the input tensor into blocks
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
//
// Construct a Tensor corresponding to each thread's slice.
//
ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CPY, CPY_M, CPY_N)
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CPY, CPY_M, CPY_N)
Tensor thr_tile_P = thr_copy.partition_S(tile_P); // (CPY, CPY_M, CPY_N)
#if 0
// Copy from GMEM to GMEM
copy_if(tiled_copy, thr_tile_P, thr_tile_S, thr_tile_D);
#else
// make_fragment_like() constructs a tensor in RMEM with the same shape as thr_tile_S.
Tensor frag = make_fragment_like(thr_tile_S);
// Copy from GMEM to RMEM and from RMEM to GMEM
copy_if(tiled_copy, thr_tile_P, thr_tile_S, frag);
copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D);
#endif
}
/// Main function
int main(int argc, char** argv)
{
//
// Given a 2D shape, perform an efficient copy
//
using namespace cute;
using Element = float;
// Define a tensor shape with dynamic extents (m, n)
auto tensor_shape = make_shape(528, 300);
thrust::host_vector<Element> h_S(size(tensor_shape));
thrust::host_vector<Element> h_D(size(tensor_shape));
//
// Initialize
//
for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
}
thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;
thrust::device_vector<Element> d_Zero = h_D;
//
// Make tensors
//
Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));
//
// Partition
//
// Define a statically sized block (M, N).
//
// Note, by convention, capital letters are used to represent static modes.
auto block_shape = make_shape(Int<128>{}, Int<64>{});
// Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
// shape, and modes (m', n') correspond to the number of tiles.
//
// These will be used to determine the CUDA kernel grid dimensions.
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
// Describes the layout of threads which is then replicated to tile 'block_shape.'
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (ThrM, ThrN)
//
// Determine grid and block dimensions
//
dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(thr_layout));
//
// Launch the kernel
//
// copy_if()
copy_if_kernel<<< gridDim, blockDim >>>(
tensor_S,
tensor_D,
block_shape,
thr_layout);
cudaError result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
return -1;
}
h_D = d_D;
//
// Verification
//
auto verify = [](thrust::host_vector<Element> const &S, thrust::host_vector<Element> const &D){
int32_t errors = 0;
int32_t const kErrorLimit = 10;
if (S.size() != D.size()) {
return 1;
}
for (size_t i = 0; i < D.size(); ++i) {
if (S[i] != D[i]) {
std::cerr << "Error. S[" << i << "]: " << S[i] << ", D[" << i << "]: " << D[i] << std::endl;
if (++errors >= kErrorLimit) {
std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl;
return errors;
}
}
}
return errors;
};
if (verify(h_D, h_S)) {
return -1;
} else {
std::cout << "Success." << std::endl;
}
thrust::copy(d_Zero.begin(), d_Zero.end(), d_D.begin());
// Construct a TiledCopy with a specific access pattern.
// This version uses a
// (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc),
// (2) Layout-of-Values that each thread will access.
// Value arrangement per thread
Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx
// Define `AccessType` which controls the size of the actual memory access instruction.
using CopyOp = UniversalCopy<uint_byte_t<sizeof(Element) * size(val_layout)>>; // A very specific access width copy instruction
//using CopyOp = UniversalCopy<cutlass::AlignedArray<Element, size(val_layout)>>; // A more generic type that supports many copy strategies
//using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs
// A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element.
using Atom = Copy_Atom<CopyOp, Element>;
// Construct tiled copy, a tiling of copy atoms.
//
// Note, this assumes the vector and thread layouts are aligned with contiguous data
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
// reads. Alternative value layouts are also possible, though incompatible layouts
// will result in compile time errors.
TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy
thr_layout, // thread layout (e.g. 32x4 Col-Major)
val_layout); // value layout (e.g. 4x1)
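// With this (32,8) thread layout and (4,1) value layout, each TiledCopy tile covers a (128,8)
// sub-tile, so partitioning the (128,64) block above yields 8 copy iterations along N per thread.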
// copy_if() with vectorization
copy_if_kernel_vectorized<<< gridDim, blockDim >>>(
tensor_S,
tensor_D,
block_shape,
tiled_copy);
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
return -1;
}
h_D = d_D;
if (verify(h_D, h_S)) {
return -1;
} else {
std::cout << "Success." << std::endl;
}
return 0;
}

View File

@ -0,0 +1,200 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass.cute as cute
import cutlass
import torch
import numpy as np
from cutlass.cute.runtime import from_dlpack
"""
A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL.
This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL.
It shows various ways to allocate different data structures in shared memory:
1. Struct allocation with natural and strict alignment
2. Raw memory block allocation with custom alignment
3. Array allocation with automatic alignment
4. Tensor allocation with layout specification
The example includes:
- Shared storage struct with mixed alignment requirements
- Memory allocation patterns for different data types
- Tensor operations on allocated memory
To run this example:
.. code-block:: bash
python examples/ampere/smem_allocator.py
The example will allocate shared memory, perform tensor operations, and verify the results.
"""
@cute.struct
class complex:
real: cutlass.Float32
imag: cutlass.Float32
# SharedStorage size is 512, alignment is 128
@cute.struct
class SharedStorage:
# struct elements with natural alignment
a: cute.struct.MemRange[cutlass.Float32, 32] # array
b: cutlass.Int64 # scalar
c: complex # nested struct
# struct elements with strict alignment
x: cute.struct.Align[
cute.struct.MemRange[cutlass.Float32, 32],
128,
]
y: cute.struct.Align[cutlass.Int32, 8]
z: cute.struct.Align[complex, 16]
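# A rough byte-layout sketch of SharedStorage (assuming the natural/strict alignments
# above; the exact offsets are illustrative, not guaranteed by the DSL):
#   a: 32 * 4 B = 128 B at offset   0
#   b:             8 B at offset 128
#   c:             8 B at offset 136
#   x: 128 B, aligned to 128 -> offset 256
#   y:   4 B, aligned to   8 -> offset 384
#   z:   8 B, aligned to  16 -> offset 400
# padded to the 128 B struct alignment -> total size 512 B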
@cute.kernel
def kernel(
const_a: cutlass.Constexpr,
dst_a: cute.Tensor,
const_b: cutlass.Constexpr,
dst_b: cute.Tensor,
const_c: cutlass.Constexpr,
dst_c: cute.Tensor,
):
# Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) are reserved for the developer to use
# Note: the alignment of the initial allocator base ptr is 1024
allocator = cutlass.utils.SmemAllocator()
# base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory)
# -- Allocate a struct --
# Note: when an alignment is specified, max(alignment, alignof(struct)) is applied
# reserves the struct's storage in smem; elements of the struct can be accessed through the returned ptr
struct_in_smem = allocator.allocate(SharedStorage)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct)
# -- Allocate a block of memory --
# reserves a section of 64 bytes in smem, aligned to 128 bytes, and returns the section base ptr
section_in_smem = allocator.allocate(64, byte_alignment=128)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section)
# -- Allocate an array --
# reserves an int64 array of size 14 in smem, returns the array base ptr
array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array)
# -- Allocate a tensor --
# Note: use cute.ComposedLayout or cute.Layout to specify layout of tensor
# Note: iterator swizzle with swizzle layout is currently not supported
layout = cute.make_layout((16, 2))
tensor_in_smem = allocator.allocate_tensor(
element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None
)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor)
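# A rough running tally under the assumptions above (1024-aligned base taken as offset 0):
#   struct  : 512 B                     -> next free offset 512
#   section :  64 B @ 128 B alignment   -> 512..576
#   array   : 112 B (14 * Int64)        -> 576..688
#   tensor  : 128 B @  32 B alignment   -> 704..832 (16 B padding)
# so the 512 + 384 = 896 bytes requested at launch below are sufficient.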
# ptr<f32, smem, align<1024>>
# ptr<i64, smem, align<128>>
# ptr<f32, smem, align<8>>
print(struct_in_smem.a.data_ptr())
print(struct_in_smem.b)
print(struct_in_smem.c.real)
# ptr<i8, smem, align<512>>
print(section_in_smem)
# ptr<i64, smem, align<64>>
print(array_in_smem)
# tensor<ptr<f32, smem, align<32>> o (16,2):(1,16)>
print(tensor_in_smem)
# fill MemRange tensor in struct and copy to dst
a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4)))
a_tensor.fill(const_a)
cute.printf("cute.struct.MemRange: {}", a_tensor)
dst_a.store(a_tensor.load())
# reinterpret the raw smem block as a Float32 tensor, fill it, and copy to dst
layout = cute.make_layout((8, 2))
sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32)
sec_tensor = cute.make_tensor(sec_ptr, layout)
sec_tensor.fill(const_b)
cute.printf("block of memory: {}", sec_tensor)
dst_b.store(sec_tensor.load())
# fill allocated tensor in smem and copy to dst
tensor_in_smem.fill(const_c)
cute.printf("tensor in smem: {}", tensor_in_smem)
dst_c.store(tensor_in_smem.load())
@cute.jit
def run_allocation_kernel(
const_a: cutlass.Constexpr,
dst_a: cute.Tensor,
const_b: cutlass.Constexpr,
dst_b: cute.Tensor,
const_c: cutlass.Constexpr,
dst_c: cute.Tensor,
):
# additional size for the example, 64(section) + 112(array) + 128(tensor) < 384
additional_bytes = 384
# Note: launch shared memory size is: SMEM_SIZE = 512 + 384 = 896 bytes
kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch(
grid=(1, 1, 1),
block=(1, 1, 1),
smem=SharedStorage.size_in_bytes() + additional_bytes,
)
def verify_allocation_kernel(const_a, const_b, const_c):
dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda")
dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda")
dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda")
run_allocation_kernel(
const_a,
from_dlpack(dst_a),
const_b,
from_dlpack(dst_b),
const_c,
from_dlpack(dst_c),
)
np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0])
np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0])
np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0])
if __name__ == "__main__":
# prepare cuda context
cutlass.cuda.initialize_cuda_context()
# An example for shared memory allocation
const_a = 0.5
const_b = 1.0
const_c = 2.0
verify_allocation_kernel(const_a, const_b, const_c)

View File

@ -0,0 +1,51 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.15)
project(tensor)
# Find Python
find_package(Python COMPONENTS Interpreter Development REQUIRED)
# Get Python site-packages directory using Python
execute_process(
COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])"
OUTPUT_VARIABLE Python_SITE_PACKAGES
OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}")
# Add nanobind path to CMAKE_PREFIX_PATH
list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake)
# Find nanobind
find_package(nanobind REQUIRED)
# Add the module
nanobind_add_module(tensor tensor.cpp)
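# Typical out-of-source build of this module (a sketch; the accompanying Python example
# drives the same two steps via subprocess):
#   cmake -B <build-dir> <source-dir>
#   cmake --build <build-dir>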

View File

@ -0,0 +1,305 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
This example demonstrates a basic approach to building custom C-structure interfaces between user code
and JIT-compiled functions. It provides a minimal-overhead way to call JIT functions
and can be used to build AOT (Ahead-of-Time) launchers for JIT-compiled functions.
The C-structure is defined as:
.. code-block:: c
struct Tensor {
void *ptr; // Pointer to tensor data
int32_t shape[3]; // Tensor dimensions
int32_t strides[3]; // Memory strides for each dimension
};
The example defines ExampleTensor and ExampleTensorValue classes that wrap a C-struct view of a tensor (its data pointer,
shape, and strides), enabling efficient data passing across language boundaries.
.. note::
Future development may include automated code generation flows.
"""
import cutlass
import cutlass.cute as cute
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
import cutlass._mlir.extras.types as T
class ExampleTensorValue(ir.Value):
"""A wrapper class for tensor values in MLIR.
This class extends ir.Value to provide convenient access to tensor data pointer,
shape, and strides through MLIR operations.
:type: ir.Value
"""
def __init__(self, v):
"""Initialize a new TensorValue.
:param v: The underlying MLIR value to wrap
:type v: ir.Value
"""
super().__init__(v)
@property
def data_ptr(self, *, loc=None, ip=None):
"""Get the data pointer from the tensor value.
Extracts the data pointer (first field) from the LLVM struct value.
:param loc: Optional location information for MLIR operations
:type loc: Optional[ir.Location]
:param ip: Optional insertion point for MLIR operations
:type ip: Optional[ir.InsertionPoint]
:return: A CuTe pointer constructed from the extracted data pointer
:rtype: cute.Pointer
"""
# Extract the data pointer from the LLVM struct value
# The data pointer is the first field (index 0) in the struct
# Use llvm.extractvalue to get the pointer field from the struct
ptr_val = llvm.extractvalue(
llvm.PointerType.get(),
self,
[0], # Extract the first field (index 0)
loc=loc,
ip=ip,
)
return cute.make_ptr(cutlass.Float32, ptr_val)
@property
def shape(self):
"""Get the shape of the tensor.
Extracts the shape (second field) from the LLVM struct value.
:return: A tuple of integers representing the tensor dimensions
:rtype: tuple[ir.Value, ...]
"""
i32_type = ir.IntegerType.get_signless(32)
# Extract the shape field from the LLVM struct value
# The shape is the second field (index 1) in the struct
shape_val = llvm.extractvalue(
llvm.StructType.get_literal([i32_type] * 3),
self,
[1], # Extract the second field (index 1)
)
# Extract each dimension from the shape struct
return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3))
@property
def stride(self):
"""Get the strides of the tensor.
Extracts the strides (third field) from the LLVM struct value.
:return: A tuple of integers representing the tensor strides
:rtype: tuple[ir.Value, ...]
"""
i32_type = ir.IntegerType.get_signless(32)
# Extract the strides field from the LLVM struct value
# The strides are the third field (index 2) in the struct
strides_val = llvm.extractvalue(
llvm.StructType.get_literal([i32_type] * 3),
self,
[2], # Extract the third field (index 2)
)
# Extract each dimension from the strides struct
return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3))
class ExampleTensor:
"""A class representing a tensor with its data pointer, shape, and strides.
This class provides a Python interface to create and manipulate tensor structures
that can be passed to CUTE JIT compiled functions.
:ivar _c_struct_p: The C struct pointer for the tensor
:ivar _rank: The number of dimensions in the tensor
"""
def __init__(self, c_struct_p, rank):
"""Initialize a new Tensor.
:param c_struct_p: The C struct pointer for the tensor
:type c_struct_p: int
:param rank: The number of dimensions in the tensor
:type rank: int
"""
self._c_struct_p = c_struct_p
self._rank = rank
def __get_mlir_types__(self):
"""Get the MLIR types for this tensor.
Creates an LLVM structure type representing a C-structure with:
.. code-block:: c
struct Tensor {
void *ptr;
int32_t shape[3];
int32_t strides[3];
};
:return: A list containing the MLIR struct type
:rtype: list[llvm.StructType]
"""
# Get the number of dimensions from the shape
ndim = self._rank
# Create the pointer type (void*)
ptr_type = llvm.PointerType.get()
# Create array types for shape and strides (int32_t[ndim])
int32_type = ir.IntegerType.get_signless(32)
shape_type = llvm.StructType.get_literal([int32_type] * ndim)
strides_type = llvm.StructType.get_literal([int32_type] * ndim)
# Create the structure type
struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type])
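# With rank 3, this corresponds roughly to the LLVM dialect type below
# (the textual form is an assumption, shown only for orientation):
#   !llvm.struct<(ptr, struct<(i32, i32, i32)>, struct<(i32, i32, i32)>)>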
return [struct_type]
def __new_from_mlir_values__(self, values):
"""Create a new TensorValue from MLIR values.
:param values: A list of MLIR values
:type values: list[ir.Value]
:return: A new TensorValue instance
:rtype: TensorValue
"""
return ExampleTensorValue(values[0])
def __c_pointers__(self):
"""Get the C pointers for this tensor.
:return: A list containing the C struct pointer
:rtype: list[int]
"""
return [self._c_struct_p]
@cute.jit
def foo(tensor):
"""Example JIT function that prints tensor information.
:param tensor: A Tensor instance to print information about
:type tensor: Tensor
"""
cute.printf("data_ptr: {}", tensor.data_ptr)
cute.printf("shape: {}", tensor.shape)
cute.printf("stride: {}", tensor.stride)
mA = cute.make_tensor(
tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride)
)
cute.print_tensor(mA)
import sys
import os
import subprocess
import shutil
import tempfile
import torch
def run_test(tmpdir=None):
# Skip cleanup if user provides tmpdir
cleanup = tmpdir is None
# Initialize temporary build directory
tmpdir = tmpdir or tempfile.mkdtemp()
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
subprocess.run(["cmake", "--build", tmpdir], check=True)
sys.path.append(tmpdir)
from tensor import make_tensor, pycapsule_get_pointer
# Mock test tensor and corresponding C structure for this example
# In production, these may come from an external library
x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
c_struct_p = pycapsule_get_pointer(c_struct)
# Initialize tensor wrapper and compile test function
tensor = ExampleTensor(c_struct_p, len(x.shape))
compiled_func = cute.compile(foo, tensor)
# Benchmark pointer access performance
from time import time
start = time()
# Measure performance of critical-path pointer access:
# getting the C pointers is on the critical path when calling the JIT-compiled function
for _ in range(1000):
tensor.__c_pointers__()
end = time()
print(f"__c_pointers__: {(end - start) * 1000} us")
# Execute compiled function
compiled_func(tensor)
except Exception as e:
print(e)
finally:
if cleanup:
# Clean up the temporary directory
shutil.rmtree(tmpdir)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Set temporary directory for building C modules"
)
parser.add_argument(
"--tmp-dir", type=str, help="Temporary directory path for building C modules"
)
args = parser.parse_args()
run_test(args.tmp_dir)

View File

@ -0,0 +1,82 @@
// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
#include <cstdint>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
// Definition of the MockTensor struct, used for testing only
struct MockTensor {
void *ptr;
struct {
int32_t shape[3];
} shape;
struct {
int32_t strides[3];
} strides;
};
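// Note: with the nested single-member structs this has the same layout as the plain
// `struct Tensor { void *ptr; int32_t shape[3]; int32_t strides[3]; }` described in the
// Python example (8 + 12 + 12 = 32 bytes, 8-byte aligned, on typical 64-bit ABIs).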
NB_MODULE(tensor, m) {
// create a tensor for testing
m.def("make_tensor", [](int64_t ptr, std::vector<int32_t> shape,
std::vector<int32_t> strides) {
auto *tensor = new MockTensor();
tensor->ptr = reinterpret_cast<void *>(ptr);
assert(shape.size() == 3 && "shape must have 3 elements");
assert(strides.size() == 3 && "strides must have 3 elements");
for (size_t i = 0; i < shape.size(); i++) {
tensor->shape.shape[i] = shape[i];
tensor->strides.strides[i] = strides[i];
}
return nb::steal(PyCapsule_New(tensor, "tensor", [](PyObject *capsule) {
auto n = PyCapsule_GetName(capsule);
if (void *p = PyCapsule_GetPointer(capsule, n)) {
delete reinterpret_cast<MockTensor *>(p);
}
}));
});
m.def(
"pycapsule_get_pointer",
[](nb::object &capsule) {
void *ptr = PyCapsule_GetPointer(capsule.ptr(), "tensor");
if (!ptr) {
throw std::runtime_error("Invalid tensor capsule");
}
return reinterpret_cast<uintptr_t>(ptr);
},
"Get pointer from PyCapsule");
}

File diff suppressed because it is too large

View File

@ -83,11 +83,6 @@
"\n",
" # Print hello world from host code\n",
" cute.printf(\"hello world\")\n",
" \n",
" # Initialize CUDA context for launching a kernel with error checking\n",
" # We make context initialization explicit to allow users to control the context creation \n",
" # and avoid potential issues with multiple contexts\n",
" cutlass.cuda.initialize_cuda_context()\n",
"\n",
" # Launch kernel\n",
" kernel().launch(\n",
@ -129,6 +124,11 @@
}
],
"source": [
"# Initialize CUDA context for launching a kernel with error checking\n",
"# We make context initialization explicit to allow users to control the context creation \n",
"# and avoid potential issues with multiple contexts\n",
"cutlass.cuda.initialize_cuda_context()\n",
"\n",
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
"print(\"Running hello_world()...\")\n",
"hello_world()\n",
@ -136,6 +136,7 @@
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
"print(\"Compiling...\")\n",
"hello_world_compiled = cute.compile(hello_world)\n",
"\n",
"# Run the pre-compiled version\n",
"print(\"Running compiled version...\")\n",
"hello_world_compiled()"