Revert "[ex77] fix mla split; add fwd lse; add bwd varlen (#2366)" (#2370)

This reverts commit f12b1d75c9.
This commit is contained in:
Manish Gupta
2025-06-05 20:14:57 -07:00
committed by GitHub
parent f12b1d75c9
commit 2e2af190bd
19 changed files with 326 additions and 846 deletions

View File

@ -117,17 +117,15 @@ struct Options {
int q = 256;
int k = 256;
int d = 128;
int warmup_iterations = 1;
int iterations = 3;
int tensor_ring_buffers = 1;
bool verify = false;
bool verbose = false;
bool causal = false;
bool residual = false;
bool varlen = false;
bool persistent = false;
int sm_count = 0;
std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom;
@ -191,15 +189,10 @@ struct Options {
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "no" || mask == "") {
@ -242,13 +235,10 @@ struct Options {
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
@ -389,55 +379,40 @@ struct FwdRunner {
StrideLSE stride_LSE;
uint64_t seed = 0;
struct DeviceBuffer {
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
DeviceBuffer() = default;
DeviceBuffer(const DeviceBuffer&) = delete;
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
size_t get_storage_size() const {
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
+ device_cumulative_seqlen_kv.get_storage_size();
}
};
std::vector<std::unique_ptr<DeviceBuffer>> buffers;
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
std::vector<int> cumulative_seqlen_q;
std::vector<int> cumulative_seqlen_kv;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
//
// Methods
//
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
bool verify(const ProblemShapeType& problem_shape) {
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
select<0,2,3>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
@ -456,7 +431,7 @@ struct FwdRunner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
@ -464,13 +439,14 @@ struct FwdRunner {
<< " mean " << mean_diff << std::endl;
}
reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
bool passed_LSE = true; // future work
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
// if ( ! passed_LSE) {
// std::cerr << "failed LSE: max diff " << max_diff
// << " mean " << mean_diff << std::endl;
// }
return passed_O && passed_LSE;
}
@ -583,71 +559,50 @@ struct FwdRunner {
get<1,1>(stride_LSE) = 0;
}
auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_LSE.reset(size(shape_LSE));
block_ref_O.reset(size(shape_QO));
block_ref_LSE.reset(size(shape_LSE));
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v);
if ( ! cumulative_seqlen_q.empty()) {
buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
buffer.device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
buffer.device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
};
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
int tensor_ring_buffers = options.tensor_ring_buffers;
for (int i = 1; i < tensor_ring_buffers; i++) {
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
if ( ! cumulative_seqlen_q.empty()) {
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
if constexpr (kIsVarlen) {
get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
}
return problem_shape;
}
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
}
typename Operation::Arguments arguments{
problem_shape_,
{ buffers[buffer_index]->block_Q.get(), stride_Q,
buffers[buffer_index]->block_K.get(), stride_K,
buffers[buffer_index]->block_V.get(), stride_V },
{ buffers[buffer_index]->block_O.get(), stride_O,
buffers[buffer_index]->block_LSE.get(), stride_LSE },
hw_info
};
return arguments;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_shape = initialize(options);
int buffer_index = 0;
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
typename Operation::Arguments arguments{
problem_shape,
{ block_Q.get(), stride_Q,
block_K.get(), stride_K,
block_V.get(), stride_V },
{ block_O.get(), stride_O,
block_LSE.get(), stride_LSE },
hw_info
};
Operation op;
@ -675,21 +630,11 @@ struct FwdRunner {
}
// Run
for (int i = 0; i < options.warmup_iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
cudaError_t result = cudaDeviceSynchronize();
@ -727,14 +672,6 @@ struct FwdRunner {
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
}
//
@ -797,7 +734,7 @@ struct FwdRunner {
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape, *buffers[0]);
passed = verify(problem_shape);
if (passed) example_result.verified = true;
}
@ -852,14 +789,10 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
using HeadDim = _128;
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -885,14 +818,10 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _64;
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
@ -916,14 +845,10 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _32;
#ifdef FP8
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
#endif
}
@ -988,7 +913,6 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;

View File

@ -120,8 +120,6 @@ struct Options {
bool verbose = false;
bool causal = false;
bool residual = false;
bool varlen = false;
int sm_count = 0;
std::string kernel_filter;
@ -192,21 +190,14 @@ struct Options {
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "causal") {
causal = true;
}
else if (mask == "residual") {
residual = true;
}
else {
causal = defaults.causal;
}
if (varlen) {
residual = true;
}
skip_reference = cmd.check_cmd_line_flag("skip-reference");
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
@ -239,12 +230,7 @@ struct Options {
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
<< " with the last batch sized to make it fit\n"
<< " implies at least residual masking for correctness\n"
<< " --mask=<no|causal> Enables masking\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
@ -321,7 +307,6 @@ struct ExampleResult {
///////////////////////////////////////////////////////////////////////////////////////////////////
template<
bool kIsVarlen,
class TileShape,
class DispatchPolicy,
class ActiveMask,
@ -337,11 +322,9 @@ struct BwdRunner {
using ElementAccumulator = float;
// Q K D (H B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, cute::tuple<int, int>>
>;
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
@ -380,9 +363,6 @@ struct BwdRunner {
DeviceAllocation<Element> block_O;
DeviceAllocation<ElementAccumulator> block_LSE;
DeviceAllocation<int> block_cumulative_seqlen_q;
DeviceAllocation<int> block_cumulative_seqlen_kv;
DeviceAllocation<Element> block_dQ;
DeviceAllocation<Element> block_dK;
DeviceAllocation<Element> block_dV;
@ -395,7 +375,7 @@ struct BwdRunner {
//
// Methods
//
bool verify(const ProblemShape& problem_shape) {
bool verify(const ProblemShapeType& problem_shape) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;
@ -479,85 +459,22 @@ struct BwdRunner {
return passed_dQ && passed_dK && passed_dV;
}
auto initialize_problem_shape(Options const& options) {
if constexpr (kIsVarlen) {
int num_batches = options.b;
// generate Q as --b times
// gaussian (--Q, --Q / 2) sampled positive
// track cumulative
std::mt19937 rng(0x202305151552ull);
std::normal_distribution<double> dist_q(options.q, options.q / 2);
std::normal_distribution<double> dist_kv(options.k, options.k / 2);
auto generate_positive_int = [](auto& dist, auto& gen) {
// "0" is a valid value we test here
return std::max(0, static_cast<int>(dist(gen)));
};
std::vector<int> cumulative_seqlen_q = {0};
std::vector<int> cumulative_seqlen_kv = {0};
int total_seqlen_q = 0;
int total_seqlen_kv = 0;
int max_seqlen_q = 0;
int max_seqlen_kv = 0;
const bool kVarlenSame = false;
for (int i = 0; i < num_batches; i++) {
int seqlen_q = kVarlenSame ? options.q : generate_positive_int(dist_q, rng);
int seqlen_kv = kVarlenSame ? options.k : generate_positive_int(dist_kv, rng);
total_seqlen_q += seqlen_q;
total_seqlen_kv += seqlen_kv;
max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
}
block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, {options.h, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1));
return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}
/// Initialize operands to be used in the GEMM and reference GEMM
ProblemShape initialize(Options const& options) {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, HB] = tensor_shape;
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B
auto shape_QO = make_shape(Q, D, make_shape(H, B));
auto shape_KV = make_shape(K, D, make_shape(H, B));
auto shape_LSE = make_shape(Q, make_shape(H, B));
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
auto shape_QO = select<0,2,3>(problem_shape);
auto shape_KV = select<1,2,3>(problem_shape);
auto shape_LSE = select<0,3>(problem_shape);
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
stride_V = stride_K;
stride_O = stride_Q;
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
stride_dQ = stride_Q;
stride_dK = stride_K;
@ -588,13 +505,6 @@ struct BwdRunner {
initialize_block(block_V, seed + 2021, options.init_style_v);
initialize_block(block_dO, seed + 2020, options.init_style_do);
initialize_block(block_dQ, seed + 2030, InitStyle::kOne);
initialize_block(block_dK, seed + 2031, InitStyle::kOne);
initialize_block(block_dV, seed + 2032, InitStyle::kOne);
initialize_block(block_ref_dQ, seed + 2033);
initialize_block(block_ref_dK, seed + 2034);
initialize_block(block_ref_dV, seed + 2035);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
@ -618,19 +528,15 @@ struct BwdRunner {
if (! options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}
return problem_shape;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
auto problem_shape = initialize(options);
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
initialize(problem_shape, options);
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
ExampleResult example_result;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, ActiveMask>;
typename Operation::Arguments arguments{
problem_shape,
block_Q.get(), stride_Q,
@ -648,6 +554,8 @@ struct BwdRunner {
Operation op;
ExampleResult example_result;
example_result.smem_size = Operation::Kernel::SharedStorageSize;
size_t workspace_size = 0;
@ -742,7 +650,7 @@ struct BwdRunner {
runtime_ms /= static_cast<float>(options.iterations);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= static_cast<double>(get<2>(problem_shape));
@ -798,28 +706,14 @@ void print_result(const std::string& description, ExampleResult result, bool ver
struct KernelCoop {};
///////////////////////////////////////////////////////////////////////////////////////////////////
template<class Fn>
auto dispatch_bool(bool value, Fn fn) {
if (value) {
return fn(std::true_type{});
}
else {
return fn(std::false_type{});
}
}
//////////////////////////////////////////////////////////////////////////////////////////////////
template<class Mask>
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using HeadDim = _64;
@ -832,11 +726,9 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
template<class Mask>
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using HeadDim = _128;
@ -911,10 +803,7 @@ int main_single(int argc, char const **args) {
auto with_causal = [&](auto fn) {
if (options.causal) {
fn(CausalForBackwardMask{});
}
else if (options.residual) {
fn(ResidualMaskForBackward{});
fn(CausalMask{});
}
else {
fn(NoMask{});

View File

@ -394,7 +394,6 @@ struct ExampleRunner {
fmha_fwd_gen_reference<ElementAcc>(
problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(),
mQ, mNewK, mNewV, mCacheK, mCacheV, mO);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
@ -409,7 +408,6 @@ struct ExampleRunner {
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
std::cerr << "failed O: max diff " << max_diff
@ -417,7 +415,6 @@ struct ExampleRunner {
}
reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff);
bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_K) {
std::cerr << "failed Cache K: max diff " << max_diff
@ -425,7 +422,6 @@ struct ExampleRunner {
}
reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff);
bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_V) {
std::cerr << "failed Cache V: max diff " << max_diff
@ -507,7 +503,6 @@ struct ExampleRunner {
block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size());
block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size());
block_seqlen_kv.reset(seqlen_kv.size());
block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size());
@ -726,7 +721,6 @@ int main_single(int argc, char const **args) {
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
return 0;
}
//
// Parse options
//

View File

@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
///////////////////////////////////////////////////////////////////////////////////////////////////
enum class InitStyle {
kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
};
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -98,9 +98,6 @@ struct Options {
if (s == "r") {
dst = InitStyle::kRandom;
}
else if (s == "l") {
dst = InitStyle::kRandomLarge;
}
else if (s == "1") {
dst = InitStyle::kOne;
}
@ -206,11 +203,6 @@ void initialize_block(
block.get(), block.size(), seed, (Element) -1, (Element) 1);
break;
}
case InitStyle::kRandomLarge: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 100);
break;
}
case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {

View File

@ -63,7 +63,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
TEST_VARLEN
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
@ -119,7 +119,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_VARLEN
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
@ -128,24 +128,20 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
endforeach()
# Add a target that builds all examples
add_custom_target(77_blackwell_fmha_all
DEPENDS
77_blackwell_fmha_fp8
77_blackwell_fmha_fp16
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen_fp16
77_blackwell_mla_2sm_fp8
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_mla_b2b_2sm_fp8
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_fmha_bwd_sat_fp8
77_blackwell_fmha_bwd_sat_fp16
)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_sat_${PREC}
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_GEN_VARLEN
TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
endforeach()
endif()

View File

@ -132,58 +132,6 @@ struct ResidualMask : NoMask {
}
};
struct ResidualMaskForBackward : NoMask {
using Base = NoMask;
template <class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return 1;
}
return 0;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
// if the sequence length does not divide the tile size evenly
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// This is useful is seqlen_k % kBlockN != 0 since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (! elem_less(pos, select<0,1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct CausalMask : NoMask {
using Base = NoMask;
@ -209,7 +157,8 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
}
template<class BlkCoord, class TileShape, class ProblemSize>
@ -248,57 +197,25 @@ struct CausalMask : NoMask {
};
struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {
using Base = CausalMask;
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is what we demonstrate here
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size);
if (masked) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct VariableLength {
int max_length;
int* cumulative_length = nullptr;
int total_length = -1;
CUTE_HOST_DEVICE operator int() const {
return max_length;
}
};
template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
template<class T> struct is_variable_length : std::false_type {};
template<> struct is_variable_length<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
return transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
@ -313,7 +230,7 @@ constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
auto new_shape = apply_variable_length(shape, idx);
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
if constexpr (is_variable_length_v<decltype(s)>) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
return cute::make_tuple(c, s.cumulative_length[idx]);
}
else {
@ -323,30 +240,6 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
return cute::make_tuple(new_shape, new_coord);
}
template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
auto idx = back(back(coord));
auto result_shape = transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
return s;
}
});
auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx];
}
else {
return _0{};
}
});
return cute::make_tuple(result_shape, result_offset);
}
} // namespace cutlass::fmha::collective
namespace cute {

View File

@ -42,7 +42,7 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_ // Q, B
class StrideLSE // Q, B
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
@ -54,7 +54,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
struct TensorStorage {
@ -80,9 +79,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
struct Params {
TMA_O tma_store_o;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
template<class ProblemShape>
@ -114,9 +110,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
);
return {
tma_store_o,
args.ptr_LSE,
args.dLSE
tma_store_o
};
}
@ -125,10 +119,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}
const Params& params;
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(

View File

@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
@ -951,8 +951,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
@ -1061,20 +1060,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
@ -1096,16 +1083,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);

View File

@ -50,19 +50,13 @@ namespace cutlass::fmha::device {
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class Element,
class ElementAccumulator,
class TileShape,
class Mask
>
template<class Element, class ElementAccumulator, class TileShape, class Mask>
class Sm100FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
// Q K D HB
ProblemShape problem_shape;
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
@ -92,16 +86,14 @@ public:
};
using OperationSumOdO = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
>;
using OperationConvert = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
>;
using Operation = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
>;
using Kernel = typename Operation::Kernel;
@ -121,15 +113,15 @@ private:
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
args.problem_size,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
sum_odo, stride_sum_OdO,
@ -141,13 +133,13 @@ private:
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
return typename OperationConvert::Arguments {
args.problem_shape,
args.problem_size,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
@ -164,7 +156,7 @@ private:
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
args.problem_size,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
@ -207,10 +199,10 @@ public:
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
@ -227,10 +219,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
@ -256,10 +248,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);

View File

@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
template<class Element, class ElementAcc>
struct FmhaKernelBwdConvert {
struct Arguments {
ProblemShape problem_shape;
tuple<int, int, int, tuple<int, int>> problem_size;
const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert {
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
return get<2>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq));
return grid;
}
@ -102,25 +102,18 @@ struct FmhaKernelBwdConvert {
return args;
}
template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) {
template<class StrideSrc, class StrideDest>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
int seqlen = count;
if constexpr (is_variable_length_v<decltype(count)>) {
int offset = count.cumulative_length[blockIdx.y];
ptr_dest_bh += offset * get<0>(stride_dest);
seqlen = count.cumulative_length[blockIdx.y + 1] - offset;
}
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
if (idx_s >= seqlen) continue;
if (idx_s >= count) continue;
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
ElementAcc value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
@ -139,13 +132,13 @@ struct FmhaKernelBwdConvert {
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape));
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape));
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape));
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size));
}
}
};

View File

@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
template<class Element, class ElementAcc>
struct FmhaKernelBwdSumOdO {
struct Arguments {
ProblemShape problem_shape;
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO {
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
return get<2>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape));
dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size));
return grid;
}
@ -110,20 +110,10 @@ struct FmhaKernelBwdSumOdO {
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;
if constexpr (is_variable_length_v<decltype(problem_q)>) {
int offset = problem_q.cumulative_length[blockIdx.z];
ptr_O_bh += offset * get<0>(params.stride_O);
ptr_dO_bh += offset * get<0>(params.stride_dO);
ptr_lse_bh += offset * get<0>(params.stride_lse);
seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset;
}
CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= seqlen_q) continue;
if (idx_q >= get<0>(params.problem_size)) continue;
ElementAcc acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
@ -131,7 +121,7 @@ struct FmhaKernelBwdSumOdO {
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];

View File

@ -90,8 +90,8 @@ struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_h;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
@ -146,7 +146,7 @@ struct PersistentTileScheduler {
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE

View File

@ -43,8 +43,6 @@
#include "collective/fmha_common.hpp"
#include <cmath>
namespace cutlass::fmha::kernel {
using namespace cutlass::fmha::collective;
@ -52,7 +50,6 @@ using namespace cutlass::fmha::collective;
using namespace cute;
template<
class ProblemShape,
class Element,
class ElementAcc,
class TileShape,
@ -277,6 +274,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using ProblemShape = Shape<int, int, int, Shape<int, int>>; // Q K D (H B), eventuall D = (D_QK, D_VO)
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
@ -362,16 +360,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, HB] = args.problem_shape;
int Q = Q_;
int K = K_;
if constexpr (is_variable_length_v<decltype(Q_)>) {
Q = Q_.total_length;
}
if constexpr (is_variable_length_v<decltype(K_)>) {
K = K_.total_length;
}
auto [Q, K, D, HB] = args.problem_shape;
auto params_kq = CollectiveMmaKQ::to_underlying_arguments(
make_shape(K, Q, D, HB),
@ -389,7 +378,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TMA_DQ tma_red_dq = make_tma_copy(
SM90_TMA_REDUCE_ADD{},
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);
@ -427,11 +416,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void load(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -452,15 +440,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
uint16_t mcast_mask = 0;
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in);
auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in);
auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
@ -495,7 +478,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
@ -532,13 +515,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async_zfill<16>(
shared_tensors.smem_lse.begin() + smem_idx,
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
@ -575,13 +556,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async<16>(
shared_tensors.smem_sum_odo.begin() + smem_idx,
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
@ -609,13 +588,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async<16>(
shared_tensors.smem_lse.begin() + smem_idx,
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
@ -639,13 +616,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async_zfill<16>(
shared_tensors.smem_sum_odo.begin() + smem_idx,
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
@ -656,10 +631,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
template<class BlkCoord, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -948,8 +923,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord,
TensorShape const& tensor_shape) {
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
@ -957,91 +930,42 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
auto thr_copy = copy_op.get_slice(_0{});
Tensor tCg = thr_copy.partition_D(gmem);
Tensor tCr = thr_copy.partition_S(quantize(regs));
Tensor tPc = thr_copy.partition_D(preds);
auto tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);
copy_if(copy_op, tPc, tCr, tCg);
}
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue_clear(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
if (threadIdx.x >= 256) {
return;
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
auto tiled_copy = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_ordered_layout(make_shape(_256{}, Int<sizeof(uint128_t) / sizeof(Element)>{}), Step<_1, _0>{}),
make_ordered_layout(make_shape(TileShapeK{}, TileShapeDQK{}), Step<_1, _0>{}));
auto thr_copy = tiled_copy.get_slice(threadIdx.x);
auto tCgDK = thr_copy.partition_D(gDK);
auto tCcDK = thr_copy.partition_S(cDK);
auto tCrDK = make_tensor<Element>(shape(tCcDK));
clear(tCrDK);
store(tCgDK, tCrDK, tCcDK, select<1,2>(problem_shape));
auto tCgDV = thr_copy.partition_D(gDV);
auto tCcDV = thr_copy.partition_S(cDV);
auto tCrDV = make_tensor<Element>(shape(tCcDV));
clear(tCrDV);
store(tCgDV, tCrDV, tCcDV, select<1,2>(problem_shape));
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void epilogue(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDKtDK.data() = TmemAllocation::kDK;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
@ -1076,13 +1000,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDVtDV.data() = TmemAllocation::kDV;
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
@ -1126,11 +1049,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void compute(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -1214,9 +1136,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));
auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST));
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
// wait for S and P
@ -1234,21 +1153,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {
// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (decltype(is_masked_tile)::value) {
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
@ -1364,15 +1275,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
epilogue(
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
blk_coord, problem_shape, mainloop_args, epilogue_args,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
}
template<class BlkCoord, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -1387,7 +1298,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
@ -1396,7 +1307,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1665,38 +1576,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z));
auto problem_shape = params.problem_shape;
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
if (iter_count <= 0) {
epilogue_clear(
blk_coord,
blk_offset,
problem_shape,
params.mainloop,
params.epilogue
);
return;
}
if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();
load(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
@ -1739,7 +1632,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
compute(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,

View File

@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
CollectiveMainloop mainloop;
CollectiveEpilogue epilogue{params.epilogue};
CollectiveEpilogue epilogue;
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
@ -407,8 +407,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
pipeline_corr_epi, pipeline_corr_epi_producer_state
);

View File

@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
}
CUTLASS_PRAGMA_UNROLL
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}

View File

@ -44,10 +44,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dQ_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
Fusion fusion) {
using namespace cute;
@ -58,28 +58,15 @@ void __global__ fmha_bwd_reference_dQ_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
@ -96,9 +83,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
@ -117,10 +104,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dK_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
/* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
Fusion fusion) {
using namespace cute;
@ -131,28 +118,15 @@ void __global__ fmha_bwd_reference_dK_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,3>(offset), mDK_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
@ -169,9 +143,9 @@ void __global__ fmha_bwd_reference_dK_kernel(
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
}
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
@ -190,10 +164,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dV_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
/* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in,
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
Fusion fusion) {
using namespace cute;
@ -204,27 +178,14 @@ void __global__ fmha_bwd_reference_dV_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
ElementAcc softmax_scale = static_cast<ElementAcc>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDV = domain_offset(select<1,2,3>(offset), mDV_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
acc_qk += rQ * rK;
@ -241,9 +202,9 @@ void __global__ fmha_bwd_reference_dV_kernel(
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
ElementAcc rS = mS[idx_Q];
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
acc += rS * rDO;

View File

@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
if (threadIdx.x == 0) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}

View File

@ -75,8 +75,6 @@ struct DeviceAllocation {
size_t size() const { return size_; }
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess);