Revert "[ex77] fix mla split; add fwd lse; add bwd varlen (#2366)" (#2370)

This reverts commit f12b1d75c9.
This commit is contained in:
Manish Gupta
2025-06-05 20:14:57 -07:00
committed by GitHub
parent f12b1d75c9
commit 2e2af190bd
19 changed files with 326 additions and 846 deletions

View File

@ -117,17 +117,15 @@ struct Options {
int q = 256;
int k = 256;
int d = 128;
int warmup_iterations = 1;
int iterations = 3;
int tensor_ring_buffers = 1;
bool verify = false;
bool verbose = false;
bool causal = false;
bool residual = false;
bool varlen = false;
bool persistent = false;
int sm_count = 0;
std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom;
@ -191,15 +189,10 @@ struct Options {
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "no" || mask == "") {
@ -217,7 +210,7 @@ struct Options {
causal = false;
}
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
@ -242,13 +235,10 @@ struct Options {
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
@ -389,55 +379,40 @@ struct FwdRunner {
StrideLSE stride_LSE;
uint64_t seed = 0;
struct DeviceBuffer {
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
DeviceBuffer() = default;
DeviceBuffer(const DeviceBuffer&) = delete;
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
size_t get_storage_size() const {
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
+ device_cumulative_seqlen_kv.get_storage_size();
}
};
std::vector<std::unique_ptr<DeviceBuffer>> buffers;
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
std::vector<int> cumulative_seqlen_q;
std::vector<int> cumulative_seqlen_kv;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
//
// Methods
//
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
bool verify(const ProblemShapeType& problem_shape) {
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
select<0,2,3>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
@ -456,7 +431,7 @@ struct FwdRunner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
@ -464,13 +439,14 @@ struct FwdRunner {
<< " mean " << mean_diff << std::endl;
}
reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
bool passed_LSE = true; // future work
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
// if ( ! passed_LSE) {
// std::cerr << "failed LSE: max diff " << max_diff
// << " mean " << mean_diff << std::endl;
// }
return passed_O && passed_LSE;
}
@ -583,71 +559,50 @@ struct FwdRunner {
get<1,1>(stride_LSE) = 0;
}
auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_LSE.reset(size(shape_LSE));
block_ref_O.reset(size(shape_QO));
block_ref_LSE.reset(size(shape_LSE));
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v);
if ( ! cumulative_seqlen_q.empty()) {
buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
buffer.device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
buffer.device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
};
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
int tensor_ring_buffers = options.tensor_ring_buffers;
for (int i = 1; i < tensor_ring_buffers; i++) {
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
if ( ! cumulative_seqlen_q.empty()) {
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
if constexpr (kIsVarlen) {
get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
}
return problem_shape;
}
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
}
typename Operation::Arguments arguments{
problem_shape_,
{ buffers[buffer_index]->block_Q.get(), stride_Q,
buffers[buffer_index]->block_K.get(), stride_K,
buffers[buffer_index]->block_V.get(), stride_V },
{ buffers[buffer_index]->block_O.get(), stride_O,
buffers[buffer_index]->block_LSE.get(), stride_LSE },
hw_info
};
return arguments;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_shape = initialize(options);
int buffer_index = 0;
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
typename Operation::Arguments arguments{
problem_shape,
{ block_Q.get(), stride_Q,
block_K.get(), stride_K,
block_V.get(), stride_V },
{ block_O.get(), stride_O,
block_LSE.get(), stride_LSE },
hw_info
};
Operation op;
@ -675,21 +630,11 @@ struct FwdRunner {
}
// Run
for (int i = 0; i < options.warmup_iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
cudaError_t result = cudaDeviceSynchronize();
@ -727,14 +672,6 @@ struct FwdRunner {
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
}
//
@ -797,10 +734,10 @@ struct FwdRunner {
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape, *buffers[0]);
passed = verify(problem_shape);
if (passed) example_result.verified = true;
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;
@ -852,14 +789,10 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
using HeadDim = _128;
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -885,14 +818,10 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _64;
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
@ -916,14 +845,10 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _32;
#ifdef FP8
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
#endif
}
@ -988,7 +913,6 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;

View File

@ -120,8 +120,6 @@ struct Options {
bool verbose = false;
bool causal = false;
bool residual = false;
bool varlen = false;
int sm_count = 0;
std::string kernel_filter;
@ -192,21 +190,14 @@ struct Options {
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "causal") {
causal = true;
}
else if (mask == "residual") {
residual = true;
}
else {
causal = defaults.causal;
}
if (varlen) {
residual = true;
}
skip_reference = cmd.check_cmd_line_flag("skip-reference");
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
@ -239,12 +230,7 @@ struct Options {
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
<< " with the last batch sized to make it fit\n"
<< " implies at least residual masking for correctness\n"
<< " --mask=<no|causal> Enables masking\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
@ -321,7 +307,6 @@ struct ExampleResult {
///////////////////////////////////////////////////////////////////////////////////////////////////
template<
bool kIsVarlen,
class TileShape,
class DispatchPolicy,
class ActiveMask,
@ -337,11 +322,9 @@ struct BwdRunner {
using ElementAccumulator = float;
// Q K D (H B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, cute::tuple<int, int>>
>;
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
@ -380,9 +363,6 @@ struct BwdRunner {
DeviceAllocation<Element> block_O;
DeviceAllocation<ElementAccumulator> block_LSE;
DeviceAllocation<int> block_cumulative_seqlen_q;
DeviceAllocation<int> block_cumulative_seqlen_kv;
DeviceAllocation<Element> block_dQ;
DeviceAllocation<Element> block_dK;
DeviceAllocation<Element> block_dV;
@ -395,7 +375,7 @@ struct BwdRunner {
//
// Methods
//
bool verify(const ProblemShape& problem_shape) {
bool verify(const ProblemShapeType& problem_shape) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;
@ -479,85 +459,22 @@ struct BwdRunner {
return passed_dQ && passed_dK && passed_dV;
}
auto initialize_problem_shape(Options const& options) {
if constexpr (kIsVarlen) {
int num_batches = options.b;
// generate Q as --b times
// gaussian (--Q, --Q / 2) sampled positive
// track cumulative
std::mt19937 rng(0x202305151552ull);
std::normal_distribution<double> dist_q(options.q, options.q / 2);
std::normal_distribution<double> dist_kv(options.k, options.k / 2);
auto generate_positive_int = [](auto& dist, auto& gen) {
// "0" is a valid value we test here
return std::max(0, static_cast<int>(dist(gen)));
};
std::vector<int> cumulative_seqlen_q = {0};
std::vector<int> cumulative_seqlen_kv = {0};
int total_seqlen_q = 0;
int total_seqlen_kv = 0;
int max_seqlen_q = 0;
int max_seqlen_kv = 0;
const bool kVarlenSame = false;
for (int i = 0; i < num_batches; i++) {
int seqlen_q = kVarlenSame ? options.q : generate_positive_int(dist_q, rng);
int seqlen_kv = kVarlenSame ? options.k : generate_positive_int(dist_kv, rng);
total_seqlen_q += seqlen_q;
total_seqlen_kv += seqlen_kv;
max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
}
block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, {options.h, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1));
return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}
/// Initialize operands to be used in the GEMM and reference GEMM
ProblemShape initialize(Options const& options) {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, HB] = tensor_shape;
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B
auto shape_QO = make_shape(Q, D, make_shape(H, B));
auto shape_KV = make_shape(K, D, make_shape(H, B));
auto shape_LSE = make_shape(Q, make_shape(H, B));
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
auto shape_QO = select<0,2,3>(problem_shape);
auto shape_KV = select<1,2,3>(problem_shape);
auto shape_LSE = select<0,3>(problem_shape);
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
stride_V = stride_K;
stride_O = stride_Q;
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
stride_dQ = stride_Q;
stride_dK = stride_K;
@ -588,13 +505,6 @@ struct BwdRunner {
initialize_block(block_V, seed + 2021, options.init_style_v);
initialize_block(block_dO, seed + 2020, options.init_style_do);
initialize_block(block_dQ, seed + 2030, InitStyle::kOne);
initialize_block(block_dK, seed + 2031, InitStyle::kOne);
initialize_block(block_dV, seed + 2032, InitStyle::kOne);
initialize_block(block_ref_dQ, seed + 2033);
initialize_block(block_ref_dK, seed + 2034);
initialize_block(block_ref_dV, seed + 2035);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
@ -618,19 +528,15 @@ struct BwdRunner {
if (! options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}
return problem_shape;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
auto problem_shape = initialize(options);
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
initialize(problem_shape, options);
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
ExampleResult example_result;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, ActiveMask>;
typename Operation::Arguments arguments{
problem_shape,
block_Q.get(), stride_Q,
@ -648,6 +554,8 @@ struct BwdRunner {
Operation op;
ExampleResult example_result;
example_result.smem_size = Operation::Kernel::SharedStorageSize;
size_t workspace_size = 0;
@ -742,7 +650,7 @@ struct BwdRunner {
runtime_ms /= static_cast<float>(options.iterations);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= static_cast<double>(get<2>(problem_shape));
@ -798,28 +706,14 @@ void print_result(const std::string& description, ExampleResult result, bool ver
struct KernelCoop {};
///////////////////////////////////////////////////////////////////////////////////////////////////
template<class Fn>
auto dispatch_bool(bool value, Fn fn) {
if (value) {
return fn(std::true_type{});
}
else {
return fn(std::false_type{});
}
}
//////////////////////////////////////////////////////////////////////////////////////////////////
template<class Mask>
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using HeadDim = _64;
@ -832,11 +726,9 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
template<class Mask>
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using HeadDim = _128;
@ -911,10 +803,7 @@ int main_single(int argc, char const **args) {
auto with_causal = [&](auto fn) {
if (options.causal) {
fn(CausalForBackwardMask{});
}
else if (options.residual) {
fn(ResidualMaskForBackward{});
fn(CausalMask{});
}
else {
fn(NoMask{});

View File

@ -394,7 +394,6 @@ struct ExampleRunner {
fmha_fwd_gen_reference<ElementAcc>(
problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(),
mQ, mNewK, mNewV, mCacheK, mCacheV, mO);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
@ -409,7 +408,6 @@ struct ExampleRunner {
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
std::cerr << "failed O: max diff " << max_diff
@ -417,7 +415,6 @@ struct ExampleRunner {
}
reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff);
bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_K) {
std::cerr << "failed Cache K: max diff " << max_diff
@ -425,7 +422,6 @@ struct ExampleRunner {
}
reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff);
bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_V) {
std::cerr << "failed Cache V: max diff " << max_diff
@ -507,7 +503,6 @@ struct ExampleRunner {
block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size());
block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size());
block_seqlen_kv.reset(seqlen_kv.size());
block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size());
@ -726,7 +721,6 @@ int main_single(int argc, char const **args) {
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
return 0;
}
//
// Parse options
//

View File

@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
///////////////////////////////////////////////////////////////////////////////////////////////////
enum class InitStyle {
kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
};
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -98,9 +98,6 @@ struct Options {
if (s == "r") {
dst = InitStyle::kRandom;
}
else if (s == "l") {
dst = InitStyle::kRandomLarge;
}
else if (s == "1") {
dst = InitStyle::kOne;
}
@ -206,11 +203,6 @@ void initialize_block(
block.get(), block.size(), seed, (Element) -1, (Element) 1);
break;
}
case InitStyle::kRandomLarge: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 100);
break;
}
case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {

View File

@ -63,7 +63,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
TEST_VARLEN
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
@ -119,7 +119,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_VARLEN
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
@ -128,24 +128,20 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
endforeach()
# Add a target that builds all examples
add_custom_target(77_blackwell_fmha_all
DEPENDS
77_blackwell_fmha_fp8
77_blackwell_fmha_fp16
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen_fp16
77_blackwell_mla_2sm_fp8
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_mla_b2b_2sm_fp8
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_fmha_bwd_sat_fp8
77_blackwell_fmha_bwd_sat_fp16
)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_sat_${PREC}
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_GEN_VARLEN
TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
endforeach()
endif()

View File

@ -132,58 +132,6 @@ struct ResidualMask : NoMask {
}
};
struct ResidualMaskForBackward : NoMask {
using Base = NoMask;
template <class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return 1;
}
return 0;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
// if the sequence length does not divide the tile size evenly
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// This is useful is seqlen_k % kBlockN != 0 since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (! elem_less(pos, select<0,1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct CausalMask : NoMask {
using Base = NoMask;
@ -209,7 +157,8 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
}
template<class BlkCoord, class TileShape, class ProblemSize>
@ -248,57 +197,25 @@ struct CausalMask : NoMask {
};
struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {
using Base = CausalMask;
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is what we demonstrate here
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size);
if (masked) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct VariableLength {
int max_length;
int* cumulative_length = nullptr;
int total_length = -1;
CUTE_HOST_DEVICE operator int() const {
return max_length;
}
};
template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
template<class T> struct is_variable_length : std::false_type {};
template<> struct is_variable_length<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
return transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
@ -313,7 +230,7 @@ constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
auto new_shape = apply_variable_length(shape, idx);
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
if constexpr (is_variable_length_v<decltype(s)>) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
return cute::make_tuple(c, s.cumulative_length[idx]);
}
else {
@ -323,30 +240,6 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
return cute::make_tuple(new_shape, new_coord);
}
template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
auto idx = back(back(coord));
auto result_shape = transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
return s;
}
});
auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx];
}
else {
return _0{};
}
});
return cute::make_tuple(result_shape, result_offset);
}
} // namespace cutlass::fmha::collective
namespace cute {

View File

@ -42,7 +42,7 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_ // Q, B
class StrideLSE // Q, B
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
@ -54,7 +54,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
struct TensorStorage {
@ -80,9 +79,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
struct Params {
TMA_O tma_store_o;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
template<class ProblemShape>
@ -114,9 +110,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
);
return {
tma_store_o,
args.ptr_LSE,
args.dLSE
tma_store_o
};
}
@ -125,10 +119,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}
const Params& params;
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(

View File

@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
#ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {
@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 16;
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
@ -951,8 +951,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
@ -962,7 +961,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
@ -1061,25 +1060,13 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
@ -1096,16 +1083,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);

View File

@ -50,19 +50,13 @@ namespace cutlass::fmha::device {
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class Element,
class ElementAccumulator,
class TileShape,
class Mask
>
template<class Element, class ElementAccumulator, class TileShape, class Mask>
class Sm100FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
// Q K D HB
ProblemShape problem_shape;
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
@ -92,16 +86,14 @@ public:
};
using OperationSumOdO = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
>;
using OperationConvert = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
>;
using Operation = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
>;
using Kernel = typename Operation::Kernel;
@ -121,15 +113,15 @@ private:
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
args.problem_size,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
sum_odo, stride_sum_OdO,
@ -141,13 +133,13 @@ private:
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
return typename OperationConvert::Arguments {
args.problem_shape,
args.problem_size,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
@ -164,7 +156,7 @@ private:
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
args.problem_size,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
@ -207,10 +199,10 @@ public:
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
@ -227,10 +219,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
@ -256,10 +248,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);

View File

@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
template<class Element, class ElementAcc>
struct FmhaKernelBwdConvert {
struct Arguments {
ProblemShape problem_shape;
tuple<int, int, int, tuple<int, int>> problem_size;
const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert {
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
return get<2>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq));
return grid;
}
@ -102,25 +102,18 @@ struct FmhaKernelBwdConvert {
return args;
}
template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) {
template<class StrideSrc, class StrideDest>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
int seqlen = count;
if constexpr (is_variable_length_v<decltype(count)>) {
int offset = count.cumulative_length[blockIdx.y];
ptr_dest_bh += offset * get<0>(stride_dest);
seqlen = count.cumulative_length[blockIdx.y + 1] - offset;
}
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
if (idx_s >= seqlen) continue;
if (idx_s >= count) continue;
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
ElementAcc value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
@ -139,13 +132,13 @@ struct FmhaKernelBwdConvert {
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape));
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape));
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape));
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size));
}
}
};

View File

@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
template<class Element, class ElementAcc>
struct FmhaKernelBwdSumOdO {
struct Arguments {
ProblemShape problem_shape;
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO {
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
return get<2>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape));
dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size));
return grid;
}
@ -110,20 +110,10 @@ struct FmhaKernelBwdSumOdO {
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;
if constexpr (is_variable_length_v<decltype(problem_q)>) {
int offset = problem_q.cumulative_length[blockIdx.z];
ptr_O_bh += offset * get<0>(params.stride_O);
ptr_dO_bh += offset * get<0>(params.stride_dO);
ptr_lse_bh += offset * get<0>(params.stride_lse);
seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset;
}
CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= seqlen_q) continue;
if (idx_q >= get<0>(params.problem_size)) continue;
ElementAcc acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
@ -131,7 +121,7 @@ struct FmhaKernelBwdSumOdO {
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];

View File

@ -82,4 +82,4 @@ struct Option {
using option_value = Value;
};
} // namespace cutlass::fmha::kernel
} // namespace cutlass::fmha::kernel

View File

@ -90,8 +90,8 @@ struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_h;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
@ -146,7 +146,7 @@ struct PersistentTileScheduler {
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE

View File

@ -43,8 +43,6 @@
#include "collective/fmha_common.hpp"
#include <cmath>
namespace cutlass::fmha::kernel {
using namespace cutlass::fmha::collective;
@ -52,7 +50,6 @@ using namespace cutlass::fmha::collective;
using namespace cute;
template<
class ProblemShape,
class Element,
class ElementAcc,
class TileShape,
@ -121,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
@ -277,6 +274,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using ProblemShape = Shape<int, int, int, Shape<int, int>>; // Q K D (H B), eventuall D = (D_QK, D_VO)
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
@ -362,16 +360,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, HB] = args.problem_shape;
int Q = Q_;
int K = K_;
if constexpr (is_variable_length_v<decltype(Q_)>) {
Q = Q_.total_length;
}
if constexpr (is_variable_length_v<decltype(K_)>) {
K = K_.total_length;
}
auto [Q, K, D, HB] = args.problem_shape;
auto params_kq = CollectiveMmaKQ::to_underlying_arguments(
make_shape(K, Q, D, HB),
@ -389,10 +378,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TMA_DQ tma_red_dq = make_tma_copy(
SM90_TMA_REDUCE_ADD{},
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);
return Params{
args.problem_shape,
args.mainloop,
@ -427,11 +416,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void load(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -452,15 +440,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
uint16_t mcast_mask = 0;
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in);
auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in);
auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
@ -469,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV);
@ -494,8 +477,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
@ -512,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
// load Q
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -532,14 +515,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async_zfill<16>(
shared_tensors.smem_lse.begin() + smem_idx,
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
@ -548,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
// load V
if (cute::elect_one_sync()) {
cute::copy(
@ -559,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
// load dO
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -575,13 +556,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async<16>(
shared_tensors.smem_sum_odo.begin() + smem_idx,
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
@ -594,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
// load Q
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -605,26 +584,24 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async<16>(
shared_tensors.smem_lse.begin() + smem_idx,
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
// load dO
if (cute::elect_one_sync()) {
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -635,18 +612,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async_zfill<16>(
shared_tensors.smem_sum_odo.begin() + smem_idx,
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
@ -656,31 +631,31 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
template<class BlkCoord, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
@ -710,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk;
@ -948,8 +923,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord,
TensorShape const& tensor_shape) {
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
@ -957,91 +930,42 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
auto thr_copy = copy_op.get_slice(_0{});
Tensor tCg = thr_copy.partition_D(gmem);
Tensor tCr = thr_copy.partition_S(quantize(regs));
Tensor tPc = thr_copy.partition_D(preds);
auto tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);
copy_if(copy_op, tPc, tCr, tCg);
}
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue_clear(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
if (threadIdx.x >= 256) {
return;
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
auto tiled_copy = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_ordered_layout(make_shape(_256{}, Int<sizeof(uint128_t) / sizeof(Element)>{}), Step<_1, _0>{}),
make_ordered_layout(make_shape(TileShapeK{}, TileShapeDQK{}), Step<_1, _0>{}));
auto thr_copy = tiled_copy.get_slice(threadIdx.x);
auto tCgDK = thr_copy.partition_D(gDK);
auto tCcDK = thr_copy.partition_S(cDK);
auto tCrDK = make_tensor<Element>(shape(tCcDK));
clear(tCrDK);
store(tCgDK, tCrDK, tCcDK, select<1,2>(problem_shape));
auto tCgDV = thr_copy.partition_D(gDV);
auto tCcDV = thr_copy.partition_S(cDV);
auto tCrDV = make_tensor<Element>(shape(tCcDV));
clear(tCrDV);
store(tCgDV, tCrDV, tCcDV, select<1,2>(problem_shape));
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void epilogue(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDKtDK.data() = TmemAllocation::kDK;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
@ -1076,13 +1000,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDVtDV.data() = TmemAllocation::kDV;
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
@ -1126,11 +1049,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void compute(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -1151,7 +1073,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, HB] = problem_shape;
// in tmem, S & P overlap
@ -1192,7 +1114,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
@ -1214,9 +1136,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));
auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST));
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
// wait for S and P
@ -1233,28 +1152,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
fn(cute::false_type{});
}
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {
// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (decltype(is_masked_tile)::value) {
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
@ -1273,16 +1184,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}
auto tRT_rST = quantize(tTR_rST);
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
});
@ -1364,15 +1275,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
epilogue(
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
blk_coord, problem_shape, mainloop_args, epilogue_args,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
}
template<class BlkCoord, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -1382,12 +1293,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
using X = Underscore;
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
@ -1396,7 +1307,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1465,7 +1376,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;
}
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync();
@ -1650,7 +1561,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
@ -1665,45 +1576,27 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z));
auto problem_shape = params.problem_shape;
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
if (iter_count <= 0) {
epilogue_clear(
blk_coord,
blk_offset,
problem_shape,
params.mainloop,
params.epilogue
);
return;
}
if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();
load(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
@ -1715,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
mma(
blk_coord,
problem_shape,
@ -1723,7 +1616,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
@ -1736,10 +1629,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();
compute(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
@ -1768,7 +1660,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();
reduce(
blk_coord,
problem_shape,
@ -1785,9 +1677,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();
/* no-op */
}
}

View File

@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
CollectiveMainloop mainloop;
CollectiveEpilogue epilogue{params.epilogue};
CollectiveEpilogue epilogue;
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
@ -407,8 +407,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
pipeline_corr_epi, pipeline_corr_epi_producer_state
);

View File

@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
}
CUTLASS_PRAGMA_UNROLL
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}

View File

@ -44,10 +44,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dQ_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
Fusion fusion) {
using namespace cute;
@ -58,28 +58,15 @@ void __global__ fmha_bwd_reference_dQ_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
@ -96,9 +83,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
@ -117,10 +104,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dK_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
/* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
Fusion fusion) {
using namespace cute;
@ -131,28 +118,15 @@ void __global__ fmha_bwd_reference_dK_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,3>(offset), mDK_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
@ -169,9 +143,9 @@ void __global__ fmha_bwd_reference_dK_kernel(
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
}
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
@ -190,10 +164,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dV_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
/* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in,
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
Fusion fusion) {
using namespace cute;
@ -204,27 +178,14 @@ void __global__ fmha_bwd_reference_dV_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
ElementAcc softmax_scale = static_cast<ElementAcc>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDV = domain_offset(select<1,2,3>(offset), mDV_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
acc_qk += rQ * rK;
@ -241,9 +202,9 @@ void __global__ fmha_bwd_reference_dV_kernel(
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
ElementAcc rS = mS[idx_Q];
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
acc += rS * rDO;

View File

@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
if (threadIdx.x == 0) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}

View File

@ -75,8 +75,6 @@ struct DeviceAllocation {
size_t size() const { return size_; }
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess);