This reverts commit f12b1d75c9.
@@ -117,17 +117,15 @@ struct Options {
int q = 256;
int k = 256;
int d = 128;
int warmup_iterations = 1;
int iterations = 3;
int tensor_ring_buffers = 1;
bool verify = false;
bool verbose = false;

bool causal = false;
bool residual = false;
bool varlen = false;
bool persistent = false;
int sm_count = 0;

std::string kernel_filter;

InitStyle init_style_q = InitStyle::kRandom;
@@ -191,15 +189,10 @@ struct Options {
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;

cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);

verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");

std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "no" || mask == "") {
@@ -217,7 +210,7 @@ struct Options {
causal = false;
}
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);


get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
@@ -242,13 +235,10 @@ struct Options {
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extent\n"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
@@ -389,55 +379,40 @@ struct FwdRunner {
StrideLSE stride_LSE;
uint64_t seed = 0;

struct DeviceBuffer {
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;

DeviceBuffer() = default;
DeviceBuffer(const DeviceBuffer&) = delete;
DeviceBuffer& operator=(const DeviceBuffer&) = delete;

size_t get_storage_size() const {
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
+ device_cumulative_seqlen_kv.get_storage_size();
}
};

std::vector<std::unique_ptr<DeviceBuffer>> buffers;
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;

std::vector<int> cumulative_seqlen_q;
std::vector<int> cumulative_seqlen_kv;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;

//
// Methods
//
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
bool verify(const ProblemShapeType& problem_shape) {
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);

Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
stride_K);

Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
stride_V);

Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
select<0,2,3>(problem_shape),
stride_O);

Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);

@@ -456,7 +431,7 @@ struct FwdRunner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);

bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
@@ -464,13 +439,14 @@ struct FwdRunner {
<< " mean " << mean_diff << std::endl;
}

reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);

bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
bool passed_LSE = true; // future work
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
// if ( ! passed_LSE) {
// std::cerr << "failed LSE: max diff " << max_diff
// << " mean " << mean_diff << std::endl;
// }

return passed_O && passed_LSE;
}
@@ -583,71 +559,50 @@ struct FwdRunner {
get<1,1>(stride_LSE) = 0;
}

auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_LSE.reset(size(shape_LSE));
block_ref_O.reset(size(shape_QO));
block_ref_LSE.reset(size(shape_LSE));

initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v);

if ( ! cumulative_seqlen_q.empty()) {
buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
buffer.device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
buffer.device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
};

buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());

int tensor_ring_buffers = options.tensor_ring_buffers;

for (int i = 1; i < tensor_ring_buffers; i++) {
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
if ( ! cumulative_seqlen_q.empty()) {
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}

if constexpr (kIsVarlen) {
get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
}

return problem_shape;
}

auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
}
typename Operation::Arguments arguments{
problem_shape_,
{ buffers[buffer_index]->block_Q.get(), stride_Q,
buffers[buffer_index]->block_K.get(), stride_K,
buffers[buffer_index]->block_V.get(), stride_V },
{ buffers[buffer_index]->block_O.get(), stride_O,
buffers[buffer_index]->block_LSE.get(), stride_LSE },
hw_info
};
return arguments;
}

ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {

ProblemShapeType problem_shape = initialize(options);

int buffer_index = 0;
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
typename Operation::Arguments arguments{
problem_shape,
{ block_Q.get(), stride_Q,
block_K.get(), stride_K,
block_V.get(), stride_V },
{ block_O.get(), stride_O,
block_LSE.get(), stride_LSE },
hw_info
};

Operation op;

@@ -675,21 +630,11 @@ struct FwdRunner {
}

// Run
for (int i = 0; i < options.warmup_iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}

cudaError_t result = cudaDeviceSynchronize();
@@ -727,14 +672,6 @@ struct FwdRunner {
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
}

//
@@ -797,10 +734,10 @@ struct FwdRunner {
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape, *buffers[0]);
passed = verify(problem_shape);
if (passed) example_result.verified = true;
}


if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;
@@ -852,14 +789,10 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn

using HeadDim = _128;

if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}

///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -885,14 +818,10 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf

using HeadDim = _64;

if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}


@@ -916,14 +845,10 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _32;

#ifdef FP8
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
#endif
}

@@ -988,7 +913,6 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}


std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;

@@ -120,8 +120,6 @@ struct Options {
bool verbose = false;

bool causal = false;
bool residual = false;
bool varlen = false;
int sm_count = 0;

std::string kernel_filter;
@@ -192,21 +190,14 @@ struct Options {
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "causal") {
causal = true;
}
else if (mask == "residual") {
residual = true;
}
else {
causal = defaults.causal;
}
if (varlen) {
residual = true;
}

skip_reference = cmd.check_cmd_line_flag("skip-reference");
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
@@ -239,12 +230,7 @@ struct Options {
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
<< " with the last batch sized to make it fit\n"
<< " implies at least residual masking for correctness\n"
<< " --mask=<no|causal> Enables masking\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
@@ -321,7 +307,6 @@ struct ExampleResult {
///////////////////////////////////////////////////////////////////////////////////////////////////

template<
bool kIsVarlen,
class TileShape,
class DispatchPolicy,
class ActiveMask,
@@ -337,11 +322,9 @@ struct BwdRunner {
using ElementAccumulator = float;

// Q K D (H B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, cute::tuple<int, int>>
>;
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;

using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;

using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
@@ -380,9 +363,6 @@ struct BwdRunner {
DeviceAllocation<Element> block_O;
DeviceAllocation<ElementAccumulator> block_LSE;

DeviceAllocation<int> block_cumulative_seqlen_q;
DeviceAllocation<int> block_cumulative_seqlen_kv;

DeviceAllocation<Element> block_dQ;
DeviceAllocation<Element> block_dK;
DeviceAllocation<Element> block_dV;
@@ -395,7 +375,7 @@ struct BwdRunner {
//
// Methods
//
bool verify(const ProblemShape& problem_shape) {
bool verify(const ProblemShapeType& problem_shape) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;

@@ -479,85 +459,22 @@ struct BwdRunner {
return passed_dQ && passed_dK && passed_dV;
}

auto initialize_problem_shape(Options const& options) {
if constexpr (kIsVarlen) {
int num_batches = options.b;

// generate Q as --b times
// gaussian (--Q, --Q / 2) sampled positive
// track cumulative
std::mt19937 rng(0x202305151552ull);
std::normal_distribution<double> dist_q(options.q, options.q / 2);
std::normal_distribution<double> dist_kv(options.k, options.k / 2);

auto generate_positive_int = [](auto& dist, auto& gen) {
// "0" is a valid value we test here
return std::max(0, static_cast<int>(dist(gen)));
};

std::vector<int> cumulative_seqlen_q = {0};
std::vector<int> cumulative_seqlen_kv = {0};

int total_seqlen_q = 0;
int total_seqlen_kv = 0;
int max_seqlen_q = 0;
int max_seqlen_kv = 0;

const bool kVarlenSame = false;
for (int i = 0; i < num_batches; i++) {
int seqlen_q = kVarlenSame ? options.q : generate_positive_int(dist_q, rng);
int seqlen_kv = kVarlenSame ? options.k : generate_positive_int(dist_kv, rng);

total_seqlen_q += seqlen_q;
total_seqlen_kv += seqlen_kv;

max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);

cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
}

block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());

ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, {options.h, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1));

return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}

/// Initialize operands to be used in the GEMM and reference GEMM
ProblemShape initialize(Options const& options) {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, HB] = tensor_shape;
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment

// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B

auto shape_QO = make_shape(Q, D, make_shape(H, B));
auto shape_KV = make_shape(K, D, make_shape(H, B));
auto shape_LSE = make_shape(Q, make_shape(H, B));

stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
auto shape_QO = select<0,2,3>(problem_shape);
auto shape_KV = select<1,2,3>(problem_shape);
auto shape_LSE = select<0,3>(problem_shape);

stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
stride_V = stride_K;
stride_O = stride_Q;
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));

stride_dQ = stride_Q;
stride_dK = stride_K;
@@ -588,13 +505,6 @@ struct BwdRunner {
initialize_block(block_V, seed + 2021, options.init_style_v);
initialize_block(block_dO, seed + 2020, options.init_style_do);

initialize_block(block_dQ, seed + 2030, InitStyle::kOne);
initialize_block(block_dK, seed + 2031, InitStyle::kOne);
initialize_block(block_dV, seed + 2032, InitStyle::kOne);
initialize_block(block_ref_dQ, seed + 2033);
initialize_block(block_ref_dK, seed + 2034);
initialize_block(block_ref_dV, seed + 2035);

Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
@@ -618,19 +528,15 @@ struct BwdRunner {
if (! options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}

return problem_shape;
}

ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
auto problem_shape = initialize(options);
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));

initialize(problem_shape, options);

ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);

ExampleResult example_result;

using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, ActiveMask>;

typename Operation::Arguments arguments{
problem_shape,
block_Q.get(), stride_Q,
@@ -648,6 +554,8 @@ struct BwdRunner {

Operation op;

ExampleResult example_result;

example_result.smem_size = Operation::Kernel::SharedStorageSize;

size_t workspace_size = 0;
@@ -742,7 +650,7 @@ struct BwdRunner {

runtime_ms /= static_cast<float>(options.iterations);

double flops = 10.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= static_cast<double>(get<2>(problem_shape));
@@ -798,28 +706,14 @@ void print_result(const std::string& description, ExampleResult result, bool ver

struct KernelCoop {};

///////////////////////////////////////////////////////////////////////////////////////////////////

template<class Fn>
auto dispatch_bool(bool value, Fn fn) {
if (value) {
return fn(std::true_type{});
}
else {
return fn(std::false_type{});
}
}

//////////////////////////////////////////////////////////////////////////////////////////////////

template<class Mask>
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};

using HeadDim = _64;
@@ -832,11 +726,9 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
template<class Mask>
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};

using HeadDim = _128;
@@ -911,10 +803,7 @@ int main_single(int argc, char const **args) {

auto with_causal = [&](auto fn) {
if (options.causal) {
fn(CausalForBackwardMask{});
}
else if (options.residual) {
fn(ResidualMaskForBackward{});
fn(CausalMask{});
}
else {
fn(NoMask{});

@@ -394,7 +394,6 @@ struct ExampleRunner {
fmha_fwd_gen_reference<ElementAcc>(
problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(),
mQ, mNewK, mNewV, mCacheK, mCacheV, mO);

cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
@@ -409,7 +408,6 @@ struct ExampleRunner {
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff);

bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
std::cerr << "failed O: max diff " << max_diff
@@ -417,7 +415,6 @@ struct ExampleRunner {
}

reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff);

bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_K) {
std::cerr << "failed Cache K: max diff " << max_diff
@@ -425,7 +422,6 @@ struct ExampleRunner {
}

reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff);

bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_V) {
std::cerr << "failed Cache V: max diff " << max_diff
@@ -507,7 +503,6 @@ struct ExampleRunner {

block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size());
block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size());

block_seqlen_kv.reset(seqlen_kv.size());
block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size());

@@ -726,7 +721,6 @@ int main_single(int argc, char const **args) {
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
return 0;
}

//
// Parse options
//

@@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
///////////////////////////////////////////////////////////////////////////////////////////////////

enum class InitStyle {
kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
};

///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -98,9 +98,6 @@ struct Options {
if (s == "r") {
dst = InitStyle::kRandom;
}
else if (s == "l") {
dst = InitStyle::kRandomLarge;
}
else if (s == "1") {
dst = InitStyle::kOne;
}
@@ -206,11 +203,6 @@ void initialize_block(
block.get(), block.size(), seed, (Element) -1, (Element) 1);
break;
}
case InitStyle::kRandomLarge: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 100);
break;
}
case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {

@@ -63,7 +63,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
TEST_VARLEN
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
@@ -119,7 +119,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_VARLEN
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
@@ -128,24 +128,20 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
endforeach()

# Add a target that builds all examples
add_custom_target(77_blackwell_fmha_all
DEPENDS
77_blackwell_fmha_fp8
77_blackwell_fmha_fp16
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen_fp16
77_blackwell_mla_2sm_fp8
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_mla_b2b_2sm_fp8
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_fmha_bwd_sat_fp8
77_blackwell_fmha_bwd_sat_fp16
)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_sat_${PREC}
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_GEN_VARLEN
TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
endforeach()
endif()

@@ -132,58 +132,6 @@ struct ResidualMask : NoMask {
}
};

struct ResidualMaskForBackward : NoMask {

using Base = NoMask;

template <class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {

if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return 1;
}
return 0;
}

template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {

// if the sequence length does not divide the tile size evenly
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
return get_trip_count(blk_coord, tile_shape, problem_size);
}

template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {

// This is useful if seqlen_k % kBlockN != 0 since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (! elem_less(pos, select<0,1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
};

struct CausalMask : NoMask {

using Base = NoMask;
@@ -209,7 +157,8 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {

return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
}

template<class BlkCoord, class TileShape, class ProblemSize>
@@ -248,57 +197,25 @@ struct CausalMask : NoMask {

};

struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {

using Base = CausalMask;

template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {

// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is what we demonstrate here
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size);
if (masked) {
acc_qk(i) = -INFINITY;
}
}
}

};

struct VariableLength {
int max_length;
int* cumulative_length = nullptr;
int total_length = -1;

CUTE_HOST_DEVICE operator int() const {
return max_length;
}
};

template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
template<class T> struct is_variable_length : std::false_type {};
template<> struct is_variable_length<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;

template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
return transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
@@ -313,7 +230,7 @@ constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
auto new_shape = apply_variable_length(shape, idx);
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
if constexpr (is_variable_length_v<decltype(s)>) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
return cute::make_tuple(c, s.cumulative_length[idx]);
}
else {
@@ -323,30 +240,6 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
return cute::make_tuple(new_shape, new_coord);
}

template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
auto idx = back(back(coord));
auto result_shape = transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
return s;
}
});
auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx];
}
else {
return _0{};
}
});
return cute::make_tuple(result_shape, result_offset);
}

} // namespace cutlass::fmha::collective

namespace cute {

@@ -42,7 +42,7 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_ // Q, B
class StrideLSE // Q, B
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

@@ -54,7 +54,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;

struct TensorStorage {

@@ -80,9 +79,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

struct Params {
TMA_O tma_store_o;

ElementAcc* ptr_LSE;
StrideLSE dLSE;
};

template<class ProblemShape>
@@ -114,9 +110,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
);

return {
tma_store_o,
args.ptr_LSE,
args.dLSE
tma_store_o
};
}

@@ -125,10 +119,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}

const Params& params;

CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}

template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(

@@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);



if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
@@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;


row_sum = local_row_sum;

if (final_call) {
@@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);

using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem

typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);


Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
@@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);


Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
@@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);

Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));


copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);


#ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {
@@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 16;

using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem

typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);


Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));

tOtO_i.data() = tOtO_i.data().get() + tmem_O;


auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);


Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
@@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);

Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));


auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
@@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}

template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
@@ -951,8 +951,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {

int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);

@@ -962,7 +961,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);


Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));

@@ -1061,25 +1060,13 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);

correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);

if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);

ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);

if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}

cutlass::arch::fence_view_async_tmem_load();

pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;


pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;

@@ -1096,16 +1083,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);

if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});

ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);

if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}

cutlass::arch::fence_view_async_tmem_load();

pipeline_o.consumer_release(pipeline_o_consumer_state);

@@ -50,19 +50,13 @@ namespace cutlass::fmha::device {
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template<
class ProblemShape,
class Element,
class ElementAccumulator,
class TileShape,
class Mask
>
template<class Element, class ElementAccumulator, class TileShape, class Mask>
class Sm100FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
// Q K D HB
ProblemShape problem_shape;
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;

const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
@@ -92,16 +86,14 @@ public:
};

using OperationSumOdO = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
>;
using OperationConvert = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
>;

using Operation = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
>;
using Kernel = typename Operation::Kernel;

@@ -121,15 +113,15 @@ private:
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
args.problem_size,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
sum_odo, stride_sum_OdO,
@@ -141,13 +133,13 @@ private:

static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
return typename OperationConvert::Arguments {
args.problem_shape,
args.problem_size,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
@@ -164,7 +156,7 @@ private:
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
args.problem_size,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
@@ -207,10 +199,10 @@ public:
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
@@ -227,10 +219,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));

auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
@@ -256,10 +248,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));

auto [Q_, K, D, HB] = args.problem_shape;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);

@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<class ProblemShape, class Element, class ElementAcc>
|
||||
template<class Element, class ElementAcc>
|
||||
struct FmhaKernelBwdConvert {
|
||||
|
||||
struct Arguments {
|
||||
ProblemShape problem_shape;
|
||||
tuple<int, int, int, tuple<int, int>> problem_size;
|
||||
|
||||
const ElementAcc* ptr_src_dQ;
|
||||
tuple<int, _1, tuple<int, int>> stride_src_dQ;
|
||||
@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert {
|
||||
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
|
||||
return get<2>(args.problem_size) % kElementsPerLoad == 0;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
|
||||
dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq));
|
||||
return grid;
|
||||
}
|
||||
|
||||
@ -102,25 +102,18 @@ struct FmhaKernelBwdConvert {
|
||||
return args;
|
||||
}
|
||||
|
||||
template<class StrideSrc, class StrideDest, class Count>
|
||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) {
|
||||
template<class StrideSrc, class StrideDest>
|
||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
|
||||
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
|
||||
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
|
||||
|
||||
int seqlen = count;
|
||||
if constexpr (is_variable_length_v<decltype(count)>) {
|
||||
int offset = count.cumulative_length[blockIdx.y];
|
||||
ptr_dest_bh += offset * get<0>(stride_dest);
|
||||
seqlen = count.cumulative_length[blockIdx.y + 1] - offset;
|
||||
}
|
||||
|
||||
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
|
||||
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
|
||||
if (idx_s >= seqlen) continue;
|
||||
if (idx_s >= count) continue;
|
||||
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
|
||||
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
|
||||
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
ElementAcc value_src[kElementsPerLoad];
|
||||
Element value_dest[kElementsPerLoad];
|
||||
|
||||
@ -139,13 +132,13 @@ struct FmhaKernelBwdConvert {
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
if (params.ptr_src_dQ != nullptr) {
|
||||
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape));
|
||||
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size));
|
||||
}
|
||||
if (params.ptr_src_dK != nullptr) {
|
||||
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape));
|
||||
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size));
|
||||
}
|
||||
if (params.ptr_src_dV != nullptr) {
|
||||
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape));
|
||||
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size));
|
||||
}
|
||||
}
|
||||
};
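The removed copy() path above dispatches on is_variable_length_v and, in the variable-length case, derives a per-batch offset and sequence length from a cumulative-length table. A minimal host-side sketch of that convention, assuming the usual B + 1 entry layout; the names VarlenView and batch_extent are illustrative and not part of the library:

#include <utility>
#include <vector>

// Assumed illustrative type: cumulative_length has B + 1 entries,
// with cumulative_length[0] == 0 and cumulative_length[B] == total length.
struct VarlenView {
  std::vector<int> cumulative_length;
};

// Mirrors the kernel's lookup: start offset into the packed buffer and
// the per-batch sequence length for batch index b.
inline std::pair<int, int> batch_extent(VarlenView const& v, int b) {
  int offset = v.cumulative_length[b];
  int seqlen = v.cumulative_length[b + 1] - offset;
  return {offset, seqlen};
}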

@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {

using namespace cute;

template<class ProblemShape, class Element, class ElementAcc>
template<class Element, class ElementAcc>
struct FmhaKernelBwdSumOdO {

struct Arguments {
ProblemShape problem_shape;
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;

const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO {
static const int kIterationsQ = kBlockQ / kNumThreadsQ;

static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
return get<2>(args.problem_size) % kElementsPerLoad == 0;
}

static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape));
dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size));
return grid;
}

@ -110,20 +110,10 @@ struct FmhaKernelBwdSumOdO {
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);

auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;
if constexpr (is_variable_length_v<decltype(problem_q)>) {
int offset = problem_q.cumulative_length[blockIdx.z];
ptr_O_bh += offset * get<0>(params.stride_O);
ptr_dO_bh += offset * get<0>(params.stride_dO);
ptr_lse_bh += offset * get<0>(params.stride_lse);
seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset;
}

CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= seqlen_q) continue;
if (idx_q >= get<0>(params.problem_size)) continue;
ElementAcc acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
@ -131,7 +121,7 @@ struct FmhaKernelBwdSumOdO {
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);

for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];


@ -82,4 +82,4 @@ struct Option {
using option_value = Value;
};

} // namespace cutlass::fmha::kernel
} // namespace cutlass::fmha::kernel

@ -90,8 +90,8 @@ struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_h;
FastDivmod divmod_b;

KernelHardwareInfo hw_info;
};
@ -146,7 +146,7 @@ struct PersistentTileScheduler {
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
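For the block-decode hunk above, FastDivmod peels the flat block index one extent at a time, producing quotient and remainder in a single call. A plain-arithmetic sketch of the same decode, assuming the linear index was composed as (bidh * B + bidb) * m_blocks + m_block; that composition order is an assumption here, not taken from the diff:

// Illustrative decode only; FastDivmod replaces each %// pair with a
// precomputed multiply-shift. Extents and their ordering are assumptions.
struct TileCoord { int m_block; int bidb; int bidh; };

inline TileCoord decode_block(int block, int m_blocks, int B, int H) {
  TileCoord c;
  c.m_block = block % m_blocks;  block /= m_blocks;
  c.bidb    = block % B;         block /= B;
  c.bidh    = block % H;         // any remaining quotient would index a further dimension
  return c;
}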

CUTLASS_DEVICE

@ -43,8 +43,6 @@

#include "collective/fmha_common.hpp"

#include <cmath>

namespace cutlass::fmha::kernel {

using namespace cutlass::fmha::collective;
@ -52,7 +50,6 @@ using namespace cutlass::fmha::collective;
using namespace cute;

template<
class ProblemShape,
class Element,
class ElementAcc,
class TileShape,
@ -121,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;


// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
@ -277,6 +274,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");

using ProblemShape = Shape<int, int, int, Shape<int, int>>; // Q K D (H B), eventually D = (D_QK, D_VO)
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)

@ -362,16 +360,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {


static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, HB] = args.problem_shape;
int Q = Q_;
int K = K_;

if constexpr (is_variable_length_v<decltype(Q_)>) {
Q = Q_.total_length;
}
if constexpr (is_variable_length_v<decltype(K_)>) {
K = K_.total_length;
}
auto [Q, K, D, HB] = args.problem_shape;

auto params_kq = CollectiveMmaKQ::to_underlying_arguments(
make_shape(K, Q, D, HB),
@ -389,10 +378,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

TMA_DQ tma_red_dq = make_tma_copy(
SM90_TMA_REDUCE_ADD{},
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);


return Params{
args.problem_shape,
args.mainloop,
@ -427,11 +416,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}


template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void load(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -452,15 +440,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

uint16_t mcast_mask = 0;

auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));

auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in);
auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in);
auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));

auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
@ -469,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});


auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV);
@ -494,8 +477,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));

// set up lse and sum_odo

auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;

auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;

pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
@ -512,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}

// load Q
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -532,14 +515,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}

cutlass::arch::cp_async_zfill<16>(
shared_tensors.smem_lse.begin() + smem_idx,
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;

@ -548,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);

pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);


// load V
if (cute::elect_one_sync()) {
cute::copy(
@ -559,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}

// load dO
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -575,13 +556,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
cutlass::arch::cp_async<16>(
shared_tensors.smem_sum_odo.begin() + smem_idx,
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
@ -594,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);

// load Q
if (cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -605,26 +584,24 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_q_producer_state;

pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);


// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}

cutlass::arch::cp_async<16>(
shared_tensors.smem_lse.begin() + smem_idx,
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;

pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);

// load dO
if (cute::elect_one_sync()) {
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -635,18 +612,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_do_producer_state;

pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);


// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}

cutlass::arch::cp_async_zfill<16>(
shared_tensors.smem_sum_odo.begin() + smem_idx,
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;

@ -656,31 +631,31 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}


template<class BlkCoord, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {


auto [Q, K, D, HB] = problem_shape;

auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
@ -710,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);


TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk;
@ -948,8 +923,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord,
TensorShape const& tensor_shape) {

Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });

auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
@ -957,91 +930,42 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
auto thr_copy = copy_op.get_slice(_0{});

Tensor tCg = thr_copy.partition_D(gmem);
Tensor tCr = thr_copy.partition_S(quantize(regs));
Tensor tPc = thr_copy.partition_D(preds);
auto tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);

copy_if(copy_op, tPc, tCr, tCg);
}
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));


template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue_clear(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {

auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;

auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);

Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);

auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);

Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);

if (threadIdx.x >= 256) {
return;
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}

auto tiled_copy = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_ordered_layout(make_shape(_256{}, Int<sizeof(uint128_t) / sizeof(Element)>{}), Step<_1, _0>{}),
make_ordered_layout(make_shape(TileShapeK{}, TileShapeDQK{}), Step<_1, _0>{}));

auto thr_copy = tiled_copy.get_slice(threadIdx.x);
auto tCgDK = thr_copy.partition_D(gDK);
auto tCcDK = thr_copy.partition_S(cDK);
auto tCrDK = make_tensor<Element>(shape(tCcDK));

clear(tCrDK);
store(tCgDK, tCrDK, tCcDK, select<1,2>(problem_shape));

auto tCgDV = thr_copy.partition_D(gDV);
auto tCcDV = thr_copy.partition_S(cDV);
auto tCrDV = make_tensor<Element>(shape(tCcDV));

clear(tCrDV);
store(tCgDV, tCrDV, tCcDV, select<1,2>(problem_shape));
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
}


template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void epilogue(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {

auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;

auto load_op = SM100_TMEM_LOAD_32dp32b16x{};

auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDKtDK.data() = TmemAllocation::kDK;

auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);

@ -1076,13 +1000,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDVtDV.data() = TmemAllocation::kDV;

auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);

Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);

@ -1126,11 +1049,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}


template<class BlkCoord, class BlkOffset, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void compute(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -1151,7 +1073,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {



auto [Q, K, D, HB] = problem_shape;

// in tmem, S & P overlap
@ -1192,7 +1114,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));


Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
@ -1214,9 +1136,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));
auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST));

bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;

CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
// wait for S and P
@ -1233,28 +1152,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
fn(cute::false_type{});
}
};

bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
}

dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {

dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {

// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);

if constexpr (decltype(is_masked_tile)::value) {

if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}


ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
@ -1273,16 +1184,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}


auto tRT_rST = quantize(tTR_rST);
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));


cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();


cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
});

@ -1364,15 +1275,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}

epilogue(
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
blk_coord, problem_shape, mainloop_args, epilogue_args,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
}

template<class BlkCoord, class ProblemShape_>
template<class BlkCoord>
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
ProblemShape const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
@ -1382,12 +1293,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {


using X = Underscore;


auto [Q, K, D, HB] = problem_shape;

auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;

// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
@ -1396,7 +1307,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDQtDQ.data() = TmemAllocation::kDQ;

Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, _, _0{}, blk_coord_batch);

Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1465,7 +1376,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;
}
}



CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync();
@ -1650,7 +1561,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;


auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
@ -1665,45 +1576,27 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

pipeline_init_wait(size(ClusterShape{}));

auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z));
auto problem_shape = params.problem_shape;
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;

if (iter_count <= 0) {
epilogue_clear(
blk_coord,
blk_offset,
problem_shape,
params.mainloop,
params.epilogue
);
return;
}

if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();


load(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
@ -1715,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();


mma(
blk_coord,
problem_shape,
@ -1723,7 +1616,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
@ -1736,10 +1629,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();


compute(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
@ -1768,7 +1660,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();


reduce(
blk_coord,
problem_shape,
@ -1785,9 +1677,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();


/* no-op */


}
}


@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();

CollectiveMainloop mainloop;
CollectiveEpilogue epilogue{params.epilogue};
CollectiveEpilogue epilogue;

if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
@ -407,8 +407,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
pipeline_corr_epi, pipeline_corr_epi_producer_state
);



@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
}

CUTLASS_PRAGMA_UNROLL
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {

sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);

ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
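The reduction hunk above merges per-split LSE values; algebraically, log(sum_i exp(lse_i - m)) + m is the same for any offset m, which only keeps exp() in a safe range (the diff picks m = params.scale * lse_max). A host-side sketch of that merge with the same zero/NaN guard; the function and parameter names here are illustrative, not taken from the library:

#include <cmath>
#include <limits>

// Merge n per-split LSE values using a caller-chosen offset m.
inline float merge_lse(float const* lse, int n, float m) {
  float sum = 0.f;
  for (int i = 0; i < n; ++i) {
    sum += std::exp(lse[i] - m);
  }
  // Mirror the kernel's guard: empty or NaN sums map to +inf.
  return (sum == 0.f || sum != sum)
             ? std::numeric_limits<float>::infinity()
             : std::log(sum) + m;
}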

@ -44,10 +44,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dQ_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
Fusion fusion) {

using namespace cute;
@ -58,28 +58,15 @@ void __global__ fmha_bwd_reference_dQ_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);

Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));

for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
@ -96,9 +83,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(

__syncthreads();

for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
@ -117,10 +104,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dK_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
/* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
Fusion fusion) {

using namespace cute;
@ -131,28 +118,15 @@ void __global__ fmha_bwd_reference_dK_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);

Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));

for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,3>(offset), mDK_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
@ -169,9 +143,9 @@ void __global__ fmha_bwd_reference_dK_kernel(

__syncthreads();

for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
}
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
@ -190,10 +164,10 @@ template<
class Fusion
>
void __global__ fmha_bwd_reference_dV_kernel(
ProblemShape problem_shape_in,
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
/* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in,
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
Fusion fusion) {

using namespace cute;
@ -204,27 +178,14 @@ void __global__ fmha_bwd_reference_dV_kernel(
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);

Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
ElementAcc softmax_scale = static_cast<ElementAcc>(1.0 / sqrt(1.0 * size<1>(mO)));

for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDV = domain_offset(select<1,2,3>(offset), mDV_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;

for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
acc_qk += rQ * rK;
@ -241,9 +202,9 @@ void __global__ fmha_bwd_reference_dV_kernel(

__syncthreads();

for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
ElementAcc rS = mS[idx_Q];
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
acc += rS * rDO;

@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}

if (threadIdx.x == 0 && mLSE.data() != nullptr) {
if (threadIdx.x == 0) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}


@ -75,8 +75,6 @@ struct DeviceAllocation {

size_t size() const { return size_; }

size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }

void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess);