This reverts commit f12b1d75c9.
This commit is contained in:
@ -117,17 +117,15 @@ struct Options {
|
|||||||
int q = 256;
|
int q = 256;
|
||||||
int k = 256;
|
int k = 256;
|
||||||
int d = 128;
|
int d = 128;
|
||||||
int warmup_iterations = 1;
|
|
||||||
int iterations = 3;
|
int iterations = 3;
|
||||||
int tensor_ring_buffers = 1;
|
|
||||||
bool verify = false;
|
bool verify = false;
|
||||||
bool verbose = false;
|
bool verbose = false;
|
||||||
|
|
||||||
bool causal = false;
|
bool causal = false;
|
||||||
bool residual = false;
|
bool residual = false;
|
||||||
bool varlen = false;
|
bool varlen = false;
|
||||||
bool persistent = false;
|
|
||||||
int sm_count = 0;
|
int sm_count = 0;
|
||||||
|
|
||||||
std::string kernel_filter;
|
std::string kernel_filter;
|
||||||
|
|
||||||
InitStyle init_style_q = InitStyle::kRandom;
|
InitStyle init_style_q = InitStyle::kRandom;
|
||||||
@ -191,15 +189,10 @@ struct Options {
|
|||||||
if (b == -1) b = 16384 / k;
|
if (b == -1) b = 16384 / k;
|
||||||
if (b == 0) b = 1;
|
if (b == 0) b = 1;
|
||||||
|
|
||||||
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
|
|
||||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||||
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);
|
|
||||||
|
|
||||||
verify = cmd.check_cmd_line_flag("verify");
|
verify = cmd.check_cmd_line_flag("verify");
|
||||||
verbose = cmd.check_cmd_line_flag("verbose");
|
verbose = cmd.check_cmd_line_flag("verbose");
|
||||||
varlen = cmd.check_cmd_line_flag("varlen");
|
varlen = cmd.check_cmd_line_flag("varlen");
|
||||||
persistent = cmd.check_cmd_line_flag("persistent");
|
|
||||||
|
|
||||||
std::string mask;
|
std::string mask;
|
||||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||||
if (mask == "no" || mask == "") {
|
if (mask == "no" || mask == "") {
|
||||||
@ -217,7 +210,7 @@ struct Options {
|
|||||||
causal = false;
|
causal = false;
|
||||||
}
|
}
|
||||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||||
|
|
||||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||||
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
|
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
|
||||||
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
|
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
|
||||||
@ -242,13 +235,10 @@ struct Options {
|
|||||||
<< " --q=<int> Sets the Q extent\n"
|
<< " --q=<int> Sets the Q extent\n"
|
||||||
<< " --k=<int> Sets the K extent\n"
|
<< " --k=<int> Sets the K extent\n"
|
||||||
<< " --d=<int> Sets the D extentn"
|
<< " --d=<int> Sets the D extentn"
|
||||||
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
|
|
||||||
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
|
|
||||||
<< " --iterations=<int> Benchmarking iterations\n"
|
<< " --iterations=<int> Benchmarking iterations\n"
|
||||||
<< " --verify Verify results\n"
|
<< " --verify Verify results\n"
|
||||||
<< " --verbose Print smem and execution time per kernel\n"
|
<< " --verbose Print smem and execution time per kernel\n"
|
||||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||||
<< " --persistent Enables persistent scheduler\n"
|
|
||||||
<< " --varlen Enables variable sequence length\n"
|
<< " --varlen Enables variable sequence length\n"
|
||||||
<< " B*Q and B*K become the total sequence length\n"
|
<< " B*Q and B*K become the total sequence length\n"
|
||||||
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
||||||
@ -389,55 +379,40 @@ struct FwdRunner {
|
|||||||
StrideLSE stride_LSE;
|
StrideLSE stride_LSE;
|
||||||
uint64_t seed = 0;
|
uint64_t seed = 0;
|
||||||
|
|
||||||
struct DeviceBuffer {
|
DeviceAllocation<Element> block_Q;
|
||||||
DeviceAllocation<Element> block_Q;
|
DeviceAllocation<Element> block_K;
|
||||||
DeviceAllocation<Element> block_K;
|
DeviceAllocation<Element> block_V;
|
||||||
DeviceAllocation<Element> block_V;
|
DeviceAllocation<ElementOut> block_O;
|
||||||
DeviceAllocation<ElementOut> block_O;
|
DeviceAllocation<ElementAccumulatorPV> block_LSE;
|
||||||
DeviceAllocation<ElementAccumulatorPV> block_LSE;
|
DeviceAllocation<ElementOut> block_ref_O;
|
||||||
DeviceAllocation<ElementOut> block_ref_O;
|
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
|
||||||
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
|
|
||||||
DeviceAllocation<int> device_cumulative_seqlen_q;
|
|
||||||
DeviceAllocation<int> device_cumulative_seqlen_kv;
|
|
||||||
|
|
||||||
DeviceBuffer() = default;
|
|
||||||
DeviceBuffer(const DeviceBuffer&) = delete;
|
|
||||||
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
|
|
||||||
|
|
||||||
size_t get_storage_size() const {
|
|
||||||
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
|
|
||||||
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
|
|
||||||
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
|
|
||||||
+ device_cumulative_seqlen_kv.get_storage_size();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<std::unique_ptr<DeviceBuffer>> buffers;
|
|
||||||
|
|
||||||
std::vector<int> cumulative_seqlen_q;
|
std::vector<int> cumulative_seqlen_q;
|
||||||
std::vector<int> cumulative_seqlen_kv;
|
std::vector<int> cumulative_seqlen_kv;
|
||||||
|
DeviceAllocation<int> device_cumulative_seqlen_q;
|
||||||
|
DeviceAllocation<int> device_cumulative_seqlen_kv;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Methods
|
// Methods
|
||||||
//
|
//
|
||||||
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
|
bool verify(const ProblemShapeType& problem_shape) {
|
||||||
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
|
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||||
select<0,2,3>(problem_shape),
|
select<0,2,3>(problem_shape),
|
||||||
stride_Q);
|
stride_Q);
|
||||||
|
|
||||||
Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
|
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||||
select<1,2,3>(problem_shape),
|
select<1,2,3>(problem_shape),
|
||||||
stride_K);
|
stride_K);
|
||||||
|
|
||||||
Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
|
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||||
select<1,2,3>(problem_shape),
|
select<1,2,3>(problem_shape),
|
||||||
stride_V);
|
stride_V);
|
||||||
|
|
||||||
Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
|
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
|
||||||
select<0,2,3>(problem_shape),
|
select<0,2,3>(problem_shape),
|
||||||
stride_O);
|
stride_O);
|
||||||
|
|
||||||
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
|
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
|
||||||
select<0,3>(problem_shape),
|
select<0,3>(problem_shape),
|
||||||
stride_LSE);
|
stride_LSE);
|
||||||
|
|
||||||
@ -456,7 +431,7 @@ struct FwdRunner {
|
|||||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||||
double max_diff = 0;
|
double max_diff = 0;
|
||||||
double mean_diff = 0;
|
double mean_diff = 0;
|
||||||
reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
|
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||||
|
|
||||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||||
if (! passed_O) {
|
if (! passed_O) {
|
||||||
@ -464,13 +439,14 @@ struct FwdRunner {
|
|||||||
<< " mean " << mean_diff << std::endl;
|
<< " mean " << mean_diff << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
|
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||||
|
|
||||||
bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
bool passed_LSE = true; // future work
|
||||||
if ( ! passed_LSE) {
|
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||||
std::cerr << "failed LSE: max diff " << max_diff
|
// if ( ! passed_LSE) {
|
||||||
<< " mean " << mean_diff << std::endl;
|
// std::cerr << "failed LSE: max diff " << max_diff
|
||||||
}
|
// << " mean " << mean_diff << std::endl;
|
||||||
|
// }
|
||||||
|
|
||||||
return passed_O && passed_LSE;
|
return passed_O && passed_LSE;
|
||||||
}
|
}
|
||||||
@ -583,71 +559,50 @@ struct FwdRunner {
|
|||||||
get<1,1>(stride_LSE) = 0;
|
get<1,1>(stride_LSE) = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto buffer_init_fn = [&](auto& buffer) {
|
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||||
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||||
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||||
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||||
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
block_LSE.reset(size(shape_LSE));
|
||||||
buffer.block_LSE.reset(size(shape_LSE));
|
block_ref_O.reset(size(shape_QO));
|
||||||
|
block_ref_LSE.reset(size(shape_LSE));
|
||||||
|
|
||||||
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
|
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||||
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
|
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||||
initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
|
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||||
|
|
||||||
if ( ! cumulative_seqlen_q.empty()) {
|
if ( ! cumulative_seqlen_q.empty()) {
|
||||||
buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
||||||
buffer.device_cumulative_seqlen_q.copy_from_host(
|
device_cumulative_seqlen_q.copy_from_host(
|
||||||
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
||||||
}
|
}
|
||||||
if ( ! cumulative_seqlen_kv.empty()) {
|
if ( ! cumulative_seqlen_kv.empty()) {
|
||||||
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
||||||
buffer.device_cumulative_seqlen_kv.copy_from_host(
|
device_cumulative_seqlen_kv.copy_from_host(
|
||||||
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
buffers.push_back(std::make_unique<DeviceBuffer>());
|
|
||||||
buffer_init_fn(*buffers.back());
|
|
||||||
|
|
||||||
int tensor_ring_buffers = options.tensor_ring_buffers;
|
|
||||||
|
|
||||||
for (int i = 1; i < tensor_ring_buffers; i++) {
|
|
||||||
buffers.push_back(std::make_unique<DeviceBuffer>());
|
|
||||||
buffer_init_fn(*buffers.back());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (kIsVarlen) {
|
if constexpr (kIsVarlen) {
|
||||||
get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
|
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
|
||||||
get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
|
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
return problem_shape;
|
return problem_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
|
|
||||||
auto problem_shape_ = problem_shape;
|
|
||||||
if constexpr (kIsVarlen) {
|
|
||||||
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
|
|
||||||
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
|
|
||||||
}
|
|
||||||
typename Operation::Arguments arguments{
|
|
||||||
problem_shape_,
|
|
||||||
{ buffers[buffer_index]->block_Q.get(), stride_Q,
|
|
||||||
buffers[buffer_index]->block_K.get(), stride_K,
|
|
||||||
buffers[buffer_index]->block_V.get(), stride_V },
|
|
||||||
{ buffers[buffer_index]->block_O.get(), stride_O,
|
|
||||||
buffers[buffer_index]->block_LSE.get(), stride_LSE },
|
|
||||||
hw_info
|
|
||||||
};
|
|
||||||
return arguments;
|
|
||||||
}
|
|
||||||
|
|
||||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||||
|
|
||||||
ProblemShapeType problem_shape = initialize(options);
|
ProblemShapeType problem_shape = initialize(options);
|
||||||
|
|
||||||
int buffer_index = 0;
|
typename Operation::Arguments arguments{
|
||||||
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
|
problem_shape,
|
||||||
|
{ block_Q.get(), stride_Q,
|
||||||
|
block_K.get(), stride_K,
|
||||||
|
block_V.get(), stride_V },
|
||||||
|
{ block_O.get(), stride_O,
|
||||||
|
block_LSE.get(), stride_LSE },
|
||||||
|
hw_info
|
||||||
|
};
|
||||||
|
|
||||||
Operation op;
|
Operation op;
|
||||||
|
|
||||||
@ -675,21 +630,11 @@ struct FwdRunner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run
|
// Run
|
||||||
for (int i = 0; i < options.warmup_iterations; i++) {
|
status = op.run();
|
||||||
status = op.run();
|
if (status != cutlass::Status::kSuccess) {
|
||||||
if (status != cutlass::Status::kSuccess) {
|
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
return example_result;
|
||||||
return example_result;
|
|
||||||
}
|
|
||||||
buffer_index = (buffer_index + 1) % buffers.size();
|
|
||||||
arguments = get_arguments(problem_shape, hw_info, buffer_index);
|
|
||||||
status = op.update(arguments, workspace.get());
|
|
||||||
if (status != cutlass::Status::kSuccess) {
|
|
||||||
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
|
|
||||||
<< std::endl;
|
|
||||||
return example_result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaError_t result = cudaDeviceSynchronize();
|
cudaError_t result = cudaDeviceSynchronize();
|
||||||
@ -727,14 +672,6 @@ struct FwdRunner {
|
|||||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||||
return example_result;
|
return example_result;
|
||||||
}
|
}
|
||||||
buffer_index = (buffer_index + 1) % buffers.size();
|
|
||||||
arguments = get_arguments(problem_shape, hw_info, buffer_index);
|
|
||||||
status = op.update(arguments, workspace.get());
|
|
||||||
if (status != cutlass::Status::kSuccess) {
|
|
||||||
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
|
|
||||||
<< std::endl;
|
|
||||||
return example_result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -797,10 +734,10 @@ struct FwdRunner {
|
|||||||
// Verify that the result is correct
|
// Verify that the result is correct
|
||||||
bool passed = true;
|
bool passed = true;
|
||||||
if (options.verify) {
|
if (options.verify) {
|
||||||
passed = verify(problem_shape, *buffers[0]);
|
passed = verify(problem_shape);
|
||||||
if (passed) example_result.verified = true;
|
if (passed) example_result.verified = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!passed) {
|
if (!passed) {
|
||||||
std::cerr << "Reference check failed" << std::endl;
|
std::cerr << "Reference check failed" << std::endl;
|
||||||
return example_result;
|
return example_result;
|
||||||
@ -852,14 +789,10 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
|
|||||||
|
|
||||||
using HeadDim = _128;
|
using HeadDim = _128;
|
||||||
|
|
||||||
if (options.persistent) {
|
// Persistent Tile Scheduler
|
||||||
// Persistent Tile Scheduler
|
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
// Individual Tile Scheduler
|
||||||
}
|
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||||
else {
|
|
||||||
// Individual Tile Scheduler
|
|
||||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -885,14 +818,10 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
|
|||||||
|
|
||||||
using HeadDim = _64;
|
using HeadDim = _64;
|
||||||
|
|
||||||
if (options.persistent) {
|
// Persistent Tile Scheduler
|
||||||
// Persistent Tile Scheduler
|
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
// Individual Tile Scheduler
|
||||||
}
|
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||||
else {
|
|
||||||
// Individual Tile Scheduler
|
|
||||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -916,14 +845,10 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
|
|||||||
using HeadDim = _32;
|
using HeadDim = _32;
|
||||||
|
|
||||||
#ifdef FP8
|
#ifdef FP8
|
||||||
if (options.persistent) {
|
// Persistent Tile Scheduler
|
||||||
// Persistent Tile Scheduler
|
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
// Individual Tile Scheduler
|
||||||
}
|
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||||
else {
|
|
||||||
// Individual Tile Scheduler
|
|
||||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -988,7 +913,6 @@ int main_single(int argc, char const **args) {
|
|||||||
hw_info.sm_count = options.sm_count;
|
hw_info.sm_count = options.sm_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
||||||
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
|
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
|
||||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||||
|
|||||||
@ -120,8 +120,6 @@ struct Options {
|
|||||||
bool verbose = false;
|
bool verbose = false;
|
||||||
|
|
||||||
bool causal = false;
|
bool causal = false;
|
||||||
bool residual = false;
|
|
||||||
bool varlen = false;
|
|
||||||
int sm_count = 0;
|
int sm_count = 0;
|
||||||
|
|
||||||
std::string kernel_filter;
|
std::string kernel_filter;
|
||||||
@ -192,21 +190,14 @@ struct Options {
|
|||||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||||
verify = cmd.check_cmd_line_flag("verify");
|
verify = cmd.check_cmd_line_flag("verify");
|
||||||
verbose = cmd.check_cmd_line_flag("verbose");
|
verbose = cmd.check_cmd_line_flag("verbose");
|
||||||
varlen = cmd.check_cmd_line_flag("varlen");
|
|
||||||
std::string mask;
|
std::string mask;
|
||||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||||
if (mask == "causal") {
|
if (mask == "causal") {
|
||||||
causal = true;
|
causal = true;
|
||||||
}
|
}
|
||||||
else if (mask == "residual") {
|
|
||||||
residual = true;
|
|
||||||
}
|
|
||||||
else {
|
else {
|
||||||
causal = defaults.causal;
|
causal = defaults.causal;
|
||||||
}
|
}
|
||||||
if (varlen) {
|
|
||||||
residual = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
skip_reference = cmd.check_cmd_line_flag("skip-reference");
|
skip_reference = cmd.check_cmd_line_flag("skip-reference");
|
||||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||||
@ -239,12 +230,7 @@ struct Options {
|
|||||||
<< " --iterations=<int> Benchmarking iterations\n"
|
<< " --iterations=<int> Benchmarking iterations\n"
|
||||||
<< " --verify Verify results\n"
|
<< " --verify Verify results\n"
|
||||||
<< " --verbose Print smem and execution time per kernel\n"
|
<< " --verbose Print smem and execution time per kernel\n"
|
||||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
<< " --mask=<no|causal> Enables masking\n"
|
||||||
<< " --varlen Enables variable sequence length\n"
|
|
||||||
<< " B*Q and B*K become the total sequence length\n"
|
|
||||||
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
|
||||||
<< " with the last batch sized to make it fit\n"
|
|
||||||
<< " implies at least residual masking for correctness\n"
|
|
||||||
<< " --sm-count Sets SM count rather than querying it\n"
|
<< " --sm-count Sets SM count rather than querying it\n"
|
||||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||||
<< "\n";
|
<< "\n";
|
||||||
@ -321,7 +307,6 @@ struct ExampleResult {
|
|||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<
|
template<
|
||||||
bool kIsVarlen,
|
|
||||||
class TileShape,
|
class TileShape,
|
||||||
class DispatchPolicy,
|
class DispatchPolicy,
|
||||||
class ActiveMask,
|
class ActiveMask,
|
||||||
@ -337,11 +322,9 @@ struct BwdRunner {
|
|||||||
using ElementAccumulator = float;
|
using ElementAccumulator = float;
|
||||||
|
|
||||||
// Q K D (H B)
|
// Q K D (H B)
|
||||||
using ProblemShape = std::conditional_t<
|
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
|
||||||
kIsVarlen,
|
|
||||||
cute::tuple<VariableLength, VariableLength, int, cute::tuple<int, int>>,
|
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
|
||||||
cute::tuple<int, int, int, cute::tuple<int, int>>
|
|
||||||
>;
|
|
||||||
|
|
||||||
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
|
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
|
||||||
using StrideQ = TensorStride;
|
using StrideQ = TensorStride;
|
||||||
@ -380,9 +363,6 @@ struct BwdRunner {
|
|||||||
DeviceAllocation<Element> block_O;
|
DeviceAllocation<Element> block_O;
|
||||||
DeviceAllocation<ElementAccumulator> block_LSE;
|
DeviceAllocation<ElementAccumulator> block_LSE;
|
||||||
|
|
||||||
DeviceAllocation<int> block_cumulative_seqlen_q;
|
|
||||||
DeviceAllocation<int> block_cumulative_seqlen_kv;
|
|
||||||
|
|
||||||
DeviceAllocation<Element> block_dQ;
|
DeviceAllocation<Element> block_dQ;
|
||||||
DeviceAllocation<Element> block_dK;
|
DeviceAllocation<Element> block_dK;
|
||||||
DeviceAllocation<Element> block_dV;
|
DeviceAllocation<Element> block_dV;
|
||||||
@ -395,7 +375,7 @@ struct BwdRunner {
|
|||||||
//
|
//
|
||||||
// Methods
|
// Methods
|
||||||
//
|
//
|
||||||
bool verify(const ProblemShape& problem_shape) {
|
bool verify(const ProblemShapeType& problem_shape) {
|
||||||
auto [Q, K, D, HB] = problem_shape;
|
auto [Q, K, D, HB] = problem_shape;
|
||||||
auto [H, B] = HB;
|
auto [H, B] = HB;
|
||||||
|
|
||||||
@ -479,85 +459,22 @@ struct BwdRunner {
|
|||||||
return passed_dQ && passed_dK && passed_dV;
|
return passed_dQ && passed_dK && passed_dV;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto initialize_problem_shape(Options const& options) {
|
|
||||||
if constexpr (kIsVarlen) {
|
|
||||||
int num_batches = options.b;
|
|
||||||
|
|
||||||
// generate Q as --b times
|
|
||||||
// gaussian (--Q, --Q / 2) sampled positive
|
|
||||||
// track cumulative
|
|
||||||
std::mt19937 rng(0x202305151552ull);
|
|
||||||
std::normal_distribution<double> dist_q(options.q, options.q / 2);
|
|
||||||
std::normal_distribution<double> dist_kv(options.k, options.k / 2);
|
|
||||||
|
|
||||||
auto generate_positive_int = [](auto& dist, auto& gen) {
|
|
||||||
// "0" is a valid value we test here
|
|
||||||
return std::max(0, static_cast<int>(dist(gen)));
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<int> cumulative_seqlen_q = {0};
|
|
||||||
std::vector<int> cumulative_seqlen_kv = {0};
|
|
||||||
|
|
||||||
int total_seqlen_q = 0;
|
|
||||||
int total_seqlen_kv = 0;
|
|
||||||
int max_seqlen_q = 0;
|
|
||||||
int max_seqlen_kv = 0;
|
|
||||||
|
|
||||||
const bool kVarlenSame = false;
|
|
||||||
for (int i = 0; i < num_batches; i++) {
|
|
||||||
int seqlen_q = kVarlenSame ? options.q : generate_positive_int(dist_q, rng);
|
|
||||||
int seqlen_kv = kVarlenSame ? options.k : generate_positive_int(dist_kv, rng);
|
|
||||||
|
|
||||||
total_seqlen_q += seqlen_q;
|
|
||||||
total_seqlen_kv += seqlen_kv;
|
|
||||||
|
|
||||||
max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
|
|
||||||
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
|
|
||||||
|
|
||||||
cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
|
|
||||||
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
|
|
||||||
}
|
|
||||||
|
|
||||||
block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
|
||||||
block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
|
||||||
block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
|
||||||
block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
|
||||||
|
|
||||||
ProblemShape problem_shape{
|
|
||||||
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
|
|
||||||
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
|
|
||||||
options.d, {options.h, options.b}
|
|
||||||
};
|
|
||||||
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1));
|
|
||||||
|
|
||||||
return cute::make_tuple(problem_shape, tensor_shape);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}};
|
|
||||||
return cute::make_tuple(problem_shape, problem_shape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||||
ProblemShape initialize(Options const& options) {
|
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
|
||||||
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
|
auto [Q, K, D, HB] = problem_shape;
|
||||||
auto [Q, K, D, HB] = tensor_shape;
|
|
||||||
auto [H, B] = HB;
|
auto [H, B] = HB;
|
||||||
D = cutlass::round_up(D, 8); // Alignment
|
D = cutlass::round_up(D, 8); // Alignment
|
||||||
|
Q = cutlass::round_up(Q, 8); // Alignment
|
||||||
|
|
||||||
// for varlen, Q == total_Q, K == total_K, B = 1
|
auto shape_QO = select<0,2,3>(problem_shape);
|
||||||
// but in problem_shape, they've got to be max_Q/max_K, and B = B
|
auto shape_KV = select<1,2,3>(problem_shape);
|
||||||
|
auto shape_LSE = select<0,3>(problem_shape);
|
||||||
auto shape_QO = make_shape(Q, D, make_shape(H, B));
|
|
||||||
auto shape_KV = make_shape(K, D, make_shape(H, B));
|
|
||||||
auto shape_LSE = make_shape(Q, make_shape(H, B));
|
|
||||||
|
|
||||||
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
|
|
||||||
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
|
|
||||||
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
|
|
||||||
|
|
||||||
|
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||||
|
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
|
||||||
stride_V = stride_K;
|
stride_V = stride_K;
|
||||||
stride_O = stride_Q;
|
stride_O = stride_Q;
|
||||||
|
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
|
||||||
|
|
||||||
stride_dQ = stride_Q;
|
stride_dQ = stride_Q;
|
||||||
stride_dK = stride_K;
|
stride_dK = stride_K;
|
||||||
@ -588,13 +505,6 @@ struct BwdRunner {
|
|||||||
initialize_block(block_V, seed + 2021, options.init_style_v);
|
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||||
initialize_block(block_dO, seed + 2020, options.init_style_do);
|
initialize_block(block_dO, seed + 2020, options.init_style_do);
|
||||||
|
|
||||||
initialize_block(block_dQ, seed + 2030, InitStyle::kOne);
|
|
||||||
initialize_block(block_dK, seed + 2031, InitStyle::kOne);
|
|
||||||
initialize_block(block_dV, seed + 2032, InitStyle::kOne);
|
|
||||||
initialize_block(block_ref_dQ, seed + 2033);
|
|
||||||
initialize_block(block_ref_dK, seed + 2034);
|
|
||||||
initialize_block(block_ref_dV, seed + 2035);
|
|
||||||
|
|
||||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||||
select<0,2,3>(problem_shape),
|
select<0,2,3>(problem_shape),
|
||||||
stride_Q);
|
stride_Q);
|
||||||
@ -618,19 +528,15 @@ struct BwdRunner {
|
|||||||
if (! options.skip_reference) {
|
if (! options.skip_reference) {
|
||||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||||
}
|
}
|
||||||
|
|
||||||
return problem_shape;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||||
auto problem_shape = initialize(options);
|
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
|
||||||
|
|
||||||
|
initialize(problem_shape, options);
|
||||||
|
|
||||||
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
|
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
|
||||||
|
|
||||||
ExampleResult example_result;
|
|
||||||
|
|
||||||
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, ActiveMask>;
|
|
||||||
|
|
||||||
typename Operation::Arguments arguments{
|
typename Operation::Arguments arguments{
|
||||||
problem_shape,
|
problem_shape,
|
||||||
block_Q.get(), stride_Q,
|
block_Q.get(), stride_Q,
|
||||||
@ -648,6 +554,8 @@ struct BwdRunner {
|
|||||||
|
|
||||||
Operation op;
|
Operation op;
|
||||||
|
|
||||||
|
ExampleResult example_result;
|
||||||
|
|
||||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||||
|
|
||||||
size_t workspace_size = 0;
|
size_t workspace_size = 0;
|
||||||
@ -742,7 +650,7 @@ struct BwdRunner {
|
|||||||
|
|
||||||
runtime_ms /= static_cast<float>(options.iterations);
|
runtime_ms /= static_cast<float>(options.iterations);
|
||||||
|
|
||||||
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
|
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
|
||||||
flops *= static_cast<double>(get<0>(problem_shape));
|
flops *= static_cast<double>(get<0>(problem_shape));
|
||||||
flops *= static_cast<double>(get<1>(problem_shape));
|
flops *= static_cast<double>(get<1>(problem_shape));
|
||||||
flops *= static_cast<double>(get<2>(problem_shape));
|
flops *= static_cast<double>(get<2>(problem_shape));
|
||||||
@ -798,28 +706,14 @@ void print_result(const std::string& description, ExampleResult result, bool ver
|
|||||||
|
|
||||||
struct KernelCoop {};
|
struct KernelCoop {};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<class Fn>
|
|
||||||
auto dispatch_bool(bool value, Fn fn) {
|
|
||||||
if (value) {
|
|
||||||
return fn(std::true_type{});
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
return fn(std::false_type{});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<class Mask>
|
template<class Mask>
|
||||||
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||||
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
auto result = runner.run(options, hw_info);
|
||||||
auto result = runner.run(options, hw_info);
|
print_result(name, result, options.verbose);
|
||||||
print_result(name, result, options.verbose);
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using HeadDim = _64;
|
using HeadDim = _64;
|
||||||
@ -832,11 +726,9 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
|
|||||||
template<class Mask>
|
template<class Mask>
|
||||||
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||||
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
auto result = runner.run(options, hw_info);
|
||||||
auto result = runner.run(options, hw_info);
|
print_result(name, result, options.verbose);
|
||||||
print_result(name, result, options.verbose);
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using HeadDim = _128;
|
using HeadDim = _128;
|
||||||
@ -911,10 +803,7 @@ int main_single(int argc, char const **args) {
|
|||||||
|
|
||||||
auto with_causal = [&](auto fn) {
|
auto with_causal = [&](auto fn) {
|
||||||
if (options.causal) {
|
if (options.causal) {
|
||||||
fn(CausalForBackwardMask{});
|
fn(CausalMask{});
|
||||||
}
|
|
||||||
else if (options.residual) {
|
|
||||||
fn(ResidualMaskForBackward{});
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
fn(NoMask{});
|
fn(NoMask{});
|
||||||
|
|||||||
@ -394,7 +394,6 @@ struct ExampleRunner {
|
|||||||
fmha_fwd_gen_reference<ElementAcc>(
|
fmha_fwd_gen_reference<ElementAcc>(
|
||||||
problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(),
|
problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(),
|
||||||
mQ, mNewK, mNewV, mCacheK, mCacheV, mO);
|
mQ, mNewK, mNewV, mCacheK, mCacheV, mO);
|
||||||
|
|
||||||
cudaError_t result = cudaDeviceSynchronize();
|
cudaError_t result = cudaDeviceSynchronize();
|
||||||
if (result != cudaSuccess) {
|
if (result != cudaSuccess) {
|
||||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||||
@ -409,7 +408,6 @@ struct ExampleRunner {
|
|||||||
double max_diff = 0;
|
double max_diff = 0;
|
||||||
double mean_diff = 0;
|
double mean_diff = 0;
|
||||||
reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff);
|
reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff);
|
||||||
|
|
||||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||||
if (! passed_O) {
|
if (! passed_O) {
|
||||||
std::cerr << "failed O: max diff " << max_diff
|
std::cerr << "failed O: max diff " << max_diff
|
||||||
@ -417,7 +415,6 @@ struct ExampleRunner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff);
|
reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff);
|
||||||
|
|
||||||
bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||||
if ( ! passed_K) {
|
if ( ! passed_K) {
|
||||||
std::cerr << "failed Cache K: max diff " << max_diff
|
std::cerr << "failed Cache K: max diff " << max_diff
|
||||||
@ -425,7 +422,6 @@ struct ExampleRunner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff);
|
reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff);
|
||||||
|
|
||||||
bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||||
if ( ! passed_V) {
|
if ( ! passed_V) {
|
||||||
std::cerr << "failed Cache V: max diff " << max_diff
|
std::cerr << "failed Cache V: max diff " << max_diff
|
||||||
@ -507,7 +503,6 @@ struct ExampleRunner {
|
|||||||
|
|
||||||
block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size());
|
block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size());
|
||||||
block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size());
|
block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size());
|
||||||
|
|
||||||
block_seqlen_kv.reset(seqlen_kv.size());
|
block_seqlen_kv.reset(seqlen_kv.size());
|
||||||
block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size());
|
block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size());
|
||||||
|
|
||||||
@ -726,7 +721,6 @@ int main_single(int argc, char const **args) {
|
|||||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Parse options
|
// Parse options
|
||||||
//
|
//
|
||||||
|
|||||||
@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
|
|||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
enum class InitStyle {
|
enum class InitStyle {
|
||||||
kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
|
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||||
};
|
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -98,9 +98,6 @@ struct Options {
|
|||||||
if (s == "r") {
|
if (s == "r") {
|
||||||
dst = InitStyle::kRandom;
|
dst = InitStyle::kRandom;
|
||||||
}
|
}
|
||||||
else if (s == "l") {
|
|
||||||
dst = InitStyle::kRandomLarge;
|
|
||||||
}
|
|
||||||
else if (s == "1") {
|
else if (s == "1") {
|
||||||
dst = InitStyle::kOne;
|
dst = InitStyle::kOne;
|
||||||
}
|
}
|
||||||
@ -206,11 +203,6 @@ void initialize_block(
|
|||||||
block.get(), block.size(), seed, (Element) -1, (Element) 1);
|
block.get(), block.size(), seed, (Element) -1, (Element) 1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case InitStyle::kRandomLarge: {
|
|
||||||
cutlass::reference::device::BlockFillRandomGaussian(
|
|
||||||
block.get(), block.size(), seed, (Element) -1, (Element) 100);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case InitStyle::kLinearStride1: {
|
case InitStyle::kLinearStride1: {
|
||||||
std::vector<Element> data(block.size());
|
std::vector<Element> data(block.size());
|
||||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||||
|
|||||||
@ -63,7 +63,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
|||||||
TEST_COMMAND_OPTIONS
|
TEST_COMMAND_OPTIONS
|
||||||
TEST_BASIC
|
TEST_BASIC
|
||||||
# TEST_CAUSAL
|
# TEST_CAUSAL
|
||||||
TEST_VARLEN
|
# TEST_VARLEN
|
||||||
# TEST_HDIM64
|
# TEST_HDIM64
|
||||||
# TEST_GQA)
|
# TEST_GQA)
|
||||||
)
|
)
|
||||||
@ -119,7 +119,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
|||||||
77_blackwell_fmha_bwd.cu
|
77_blackwell_fmha_bwd.cu
|
||||||
TEST_COMMAND_OPTIONS
|
TEST_COMMAND_OPTIONS
|
||||||
TEST_BASIC
|
TEST_BASIC
|
||||||
TEST_VARLEN
|
# TEST_GEN_VARLEN
|
||||||
# TEST_GEN_HDIM64
|
# TEST_GEN_HDIM64
|
||||||
# TEST_GEN_GQA
|
# TEST_GEN_GQA
|
||||||
# TEST_GEN_REMAP
|
# TEST_GEN_REMAP
|
||||||
@ -128,24 +128,20 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
|||||||
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||||
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
|
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
|
||||||
endforeach()
|
|
||||||
|
|
||||||
# Add a target that builds all examples
|
cutlass_example_add_executable(
|
||||||
add_custom_target(77_blackwell_fmha_all
|
77_blackwell_fmha_bwd_sat_${PREC}
|
||||||
DEPENDS
|
77_blackwell_fmha_bwd.cu
|
||||||
77_blackwell_fmha_fp8
|
TEST_COMMAND_OPTIONS
|
||||||
77_blackwell_fmha_fp16
|
TEST_BASIC
|
||||||
77_blackwell_fmha_gen_fp8
|
# TEST_GEN_VARLEN
|
||||||
77_blackwell_fmha_gen_fp16
|
TEST_GEN_HDIM64
|
||||||
77_blackwell_mla_2sm_fp8
|
# TEST_GEN_GQA
|
||||||
77_blackwell_mla_2sm_fp16
|
# TEST_GEN_REMAP
|
||||||
77_blackwell_mla_2sm_cpasync_fp8
|
# TEST_GEN_CACHEONLY)
|
||||||
77_blackwell_mla_2sm_cpasync_fp16
|
)
|
||||||
77_blackwell_mla_b2b_2sm_fp8
|
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
77_blackwell_mla_b2b_2sm_fp16
|
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
|
||||||
77_blackwell_fmha_bwd_fp8
|
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
|
||||||
77_blackwell_fmha_bwd_fp16
|
endforeach()
|
||||||
77_blackwell_fmha_bwd_sat_fp8
|
|
||||||
77_blackwell_fmha_bwd_sat_fp16
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@ -132,58 +132,6 @@ struct ResidualMask : NoMask {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ResidualMaskForBackward : NoMask {
|
|
||||||
|
|
||||||
using Base = NoMask;
|
|
||||||
|
|
||||||
template <class BlkCoord, class TileShape, class ProblemSize>
|
|
||||||
CUTLASS_DEVICE int get_masked_trip_count(
|
|
||||||
BlkCoord const& blk_coord,
|
|
||||||
TileShape const& tile_shape,
|
|
||||||
ProblemSize const& problem_size) {
|
|
||||||
|
|
||||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
int get_unmasked_trip_count(
|
|
||||||
BlkCoord const& blk_coord,
|
|
||||||
TileShape const& tile_shape,
|
|
||||||
ProblemSize const& problem_size) {
|
|
||||||
|
|
||||||
// if the sequence length does not divide the tile size evenly
|
|
||||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
|
||||||
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
|
|
||||||
}
|
|
||||||
return get_trip_count(blk_coord, tile_shape, problem_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class AccQK, class IndexQK, class ProblemSize>
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void apply_mask(
|
|
||||||
AccQK& acc_qk,
|
|
||||||
IndexQK const& index_qk,
|
|
||||||
ProblemSize const& problem_size) {
|
|
||||||
|
|
||||||
// This is useful is seqlen_k % kBlockN != 0 since it masks
|
|
||||||
// the remaining elements out from softmax.
|
|
||||||
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
|
|
||||||
// issues as they are transparently taken care of by TMA and the
|
|
||||||
// epilogue, if it is instantiated with predication support.
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < size(acc_qk); i++) {
|
|
||||||
auto pos = index_qk(i);
|
|
||||||
if (! elem_less(pos, select<0,1>(problem_size))) {
|
|
||||||
acc_qk(i) = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CausalMask : NoMask {
|
struct CausalMask : NoMask {
|
||||||
|
|
||||||
using Base = NoMask;
|
using Base = NoMask;
|
||||||
@ -209,7 +157,8 @@ struct CausalMask : NoMask {
|
|||||||
TileShape const& tile_shape,
|
TileShape const& tile_shape,
|
||||||
ProblemSize const& problem_size) {
|
ProblemSize const& problem_size) {
|
||||||
|
|
||||||
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
|
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||||
|
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||||
@ -248,57 +197,25 @@ struct CausalMask : NoMask {
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {
|
|
||||||
|
|
||||||
using Base = CausalMask;
|
|
||||||
|
|
||||||
template<class AccQK, class IndexQK, class ProblemSize>
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void apply_mask(
|
|
||||||
AccQK& acc_qk,
|
|
||||||
IndexQK const& index_qk,
|
|
||||||
ProblemSize const& problem_size) {
|
|
||||||
|
|
||||||
// There are two ways to do causal if N_Q != N_K
|
|
||||||
// (1) is to assume that the Q is at the beginning of the matrix
|
|
||||||
// - this is what we demonstrate here
|
|
||||||
// (2) is that it is at the end of the matrix
|
|
||||||
// - this is usually what we want for inference settings
|
|
||||||
// where we only compute the next row and use cache for the rest
|
|
||||||
// - if you'd like this, you only need to add an offset like so:
|
|
||||||
// get<0>(pos) + offset_q < get<1>(pos)
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < size(acc_qk); i++) {
|
|
||||||
auto pos = index_qk(i);
|
|
||||||
bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size);
|
|
||||||
if (masked) {
|
|
||||||
acc_qk(i) = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
struct VariableLength {
|
struct VariableLength {
|
||||||
int max_length;
|
int max_length;
|
||||||
int* cumulative_length = nullptr;
|
int* cumulative_length = nullptr;
|
||||||
int total_length = -1;
|
|
||||||
|
|
||||||
CUTE_HOST_DEVICE operator int() const {
|
CUTE_HOST_DEVICE operator int() const {
|
||||||
return max_length;
|
return max_length;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<class T> struct is_variable_length_impl : std::false_type {};
|
template<class T> struct is_variable_length : std::false_type {};
|
||||||
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
|
template<> struct is_variable_length<VariableLength> : std::true_type {};
|
||||||
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
|
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
|
||||||
|
|
||||||
template<class Shape, class Idx>
|
template<class Shape, class Idx>
|
||||||
CUTE_HOST_DEVICE
|
CUTE_HOST_DEVICE
|
||||||
constexpr auto
|
constexpr auto
|
||||||
apply_variable_length(Shape const& shape, Idx const& idx) {
|
apply_variable_length(Shape const& shape, Idx const& idx) {
|
||||||
return transform_leaf(shape, [&](auto const& s) {
|
return transform_leaf(shape, [&](auto const& s) {
|
||||||
if constexpr (is_variable_length_v<decltype(s)>) {
|
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
|
||||||
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
|
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@ -313,7 +230,7 @@ constexpr auto
|
|||||||
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
|
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
|
||||||
auto new_shape = apply_variable_length(shape, idx);
|
auto new_shape = apply_variable_length(shape, idx);
|
||||||
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
|
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
|
||||||
if constexpr (is_variable_length_v<decltype(s)>) {
|
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
|
||||||
return cute::make_tuple(c, s.cumulative_length[idx]);
|
return cute::make_tuple(c, s.cumulative_length[idx]);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@ -323,30 +240,6 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
|
|||||||
return cute::make_tuple(new_shape, new_coord);
|
return cute::make_tuple(new_shape, new_coord);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class Shape, class Coord>
|
|
||||||
CUTE_HOST_DEVICE
|
|
||||||
constexpr auto
|
|
||||||
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
|
|
||||||
auto idx = back(back(coord));
|
|
||||||
auto result_shape = transform_leaf(shape, [&](auto const& s) {
|
|
||||||
if constexpr (is_variable_length_v<decltype(s)>) {
|
|
||||||
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
|
|
||||||
if constexpr (is_variable_length_v<decltype(s)>) {
|
|
||||||
return s.cumulative_length[idx];
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
return _0{};
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return cute::make_tuple(result_shape, result_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cutlass::fmha::collective
|
} // namespace cutlass::fmha::collective
|
||||||
|
|
||||||
namespace cute {
|
namespace cute {
|
||||||
|
|||||||
@ -42,7 +42,7 @@ template<
|
|||||||
class ElementAcc,
|
class ElementAcc,
|
||||||
class TileShape, // Q, D, _
|
class TileShape, // Q, D, _
|
||||||
class StrideO, // Q, D, B
|
class StrideO, // Q, D, B
|
||||||
class StrideLSE_ // Q, B
|
class StrideLSE // Q, B
|
||||||
>
|
>
|
||||||
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||||
|
|
||||||
@ -54,7 +54,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
|||||||
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
|
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
|
||||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
|
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
|
||||||
using SmemLayoutO_ = SmemLayoutO;
|
using SmemLayoutO_ = SmemLayoutO;
|
||||||
using StrideLSE = StrideLSE_;
|
|
||||||
|
|
||||||
struct TensorStorage {
|
struct TensorStorage {
|
||||||
|
|
||||||
@ -80,9 +79,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
|||||||
|
|
||||||
struct Params {
|
struct Params {
|
||||||
TMA_O tma_store_o;
|
TMA_O tma_store_o;
|
||||||
|
|
||||||
ElementAcc* ptr_LSE;
|
|
||||||
StrideLSE dLSE;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<class ProblemShape>
|
template<class ProblemShape>
|
||||||
@ -114,9 +110,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
|||||||
);
|
);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
tma_store_o,
|
tma_store_o
|
||||||
args.ptr_LSE,
|
|
||||||
args.dLSE
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,10 +119,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
|||||||
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
|
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
|
||||||
}
|
}
|
||||||
|
|
||||||
const Params& params;
|
|
||||||
|
|
||||||
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
|
|
||||||
|
|
||||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||||
CUTLASS_DEVICE auto
|
CUTLASS_DEVICE auto
|
||||||
store(
|
store(
|
||||||
|
|||||||
@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
}
|
}
|
||||||
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
|
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
|
||||||
|
|
||||||
|
|
||||||
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
|
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
|
||||||
order_s.arrive();
|
order_s.arrive();
|
||||||
}
|
}
|
||||||
@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
|
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
|
||||||
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
|
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
|
||||||
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
|
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
|
||||||
|
|
||||||
row_sum = local_row_sum;
|
row_sum = local_row_sum;
|
||||||
|
|
||||||
if (final_call) {
|
if (final_call) {
|
||||||
@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
// good values would be either 32 or 64
|
// good values would be either 32 or 64
|
||||||
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
|
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
|
||||||
|
|
||||||
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
||||||
|
|
||||||
typename CollectiveMmaPV::TiledMma mma;
|
typename CollectiveMmaPV::TiledMma mma;
|
||||||
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
||||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||||
Tensor tOsO = mma.get_slice(0).partition_C(sO);
|
Tensor tOsO = mma.get_slice(0).partition_C(sO);
|
||||||
|
|
||||||
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
|
|
||||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
|
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
|
||||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||||
|
|
||||||
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
|
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
|
||||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
|
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
|
||||||
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
|
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
|
||||||
@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
|
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
|
||||||
|
|
||||||
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
|
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
|
||||||
|
|
||||||
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
|
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
|
||||||
|
|
||||||
#ifndef ONLY_SOFTMAX
|
#ifndef ONLY_SOFTMAX
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int j = 0; j < size(tTMrO); j += 2) {
|
for (int j = 0; j < size(tTMrO); j += 2) {
|
||||||
@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
// good values would be either 32 or 64
|
// good values would be either 32 or 64
|
||||||
const int kCorrectionTileSize = 16;
|
const int kCorrectionTileSize = 16;
|
||||||
|
|
||||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
||||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
||||||
|
|
||||||
typename CollectiveMmaPV::TiledMma mma;
|
typename CollectiveMmaPV::TiledMma mma;
|
||||||
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
||||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||||
|
|
||||||
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
|
|
||||||
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
|
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
|
||||||
|
|
||||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
|
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
|
||||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||||
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
|
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
|
||||||
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
|
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
|
||||||
|
|
||||||
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
|
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
|
||||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
|
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
|
||||||
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
|
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
|
||||||
@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
float2 scale_f32x2 = make_float2(scale, scale);
|
float2 scale_f32x2 = make_float2(scale, scale);
|
||||||
|
|
||||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
||||||
|
|
||||||
auto copy_in = [&](int i) {
|
auto copy_in = [&](int i) {
|
||||||
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
||||||
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||||
@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
|
template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
|
||||||
CUTLASS_DEVICE auto
|
CUTLASS_DEVICE auto
|
||||||
correction(
|
correction(
|
||||||
BlkCoord const& blk_coord,
|
BlkCoord const& blk_coord,
|
||||||
@ -951,8 +951,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
|
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
|
||||||
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
|
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
|
||||||
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
|
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
|
||||||
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
|
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
|
||||||
CollectiveEpilogue& epilogue) {
|
|
||||||
|
|
||||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
|
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
|
||||||
|
|
||||||
@ -962,7 +961,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
|
|
||||||
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
|
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
|
||||||
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
|
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
|
||||||
|
|
||||||
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
|
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||||
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||||
|
|
||||||
@ -1061,25 +1060,13 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
// F2FP
|
// F2FP
|
||||||
// store to smem
|
// store to smem
|
||||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
|
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
|
||||||
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
|
|
||||||
|
|
||||||
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
|
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
|
||||||
|
|
||||||
if (epilogue.params.ptr_LSE != nullptr) {
|
|
||||||
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
|
|
||||||
|
|
||||||
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
|
|
||||||
|
|
||||||
if (row_idx < get<0>(problem_shape)) {
|
|
||||||
gLSE(row_idx, get<2>(blk_coord)) = lse;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cutlass::arch::fence_view_async_tmem_load();
|
cutlass::arch::fence_view_async_tmem_load();
|
||||||
|
|
||||||
pipeline_o.consumer_release(pipeline_o_consumer_state);
|
pipeline_o.consumer_release(pipeline_o_consumer_state);
|
||||||
++pipeline_o_consumer_state;
|
++pipeline_o_consumer_state;
|
||||||
|
|
||||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||||
++pipeline_epi_producer_state;
|
++pipeline_epi_producer_state;
|
||||||
|
|
||||||
@ -1096,16 +1083,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
|
|
||||||
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
|
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
|
||||||
|
|
||||||
if (epilogue.params.ptr_LSE != nullptr) {
|
|
||||||
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
|
|
||||||
|
|
||||||
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
|
|
||||||
|
|
||||||
if (row_idx < get<0>(problem_shape)) {
|
|
||||||
gLSE(row_idx, get<2>(blk_coord)) = lse;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cutlass::arch::fence_view_async_tmem_load();
|
cutlass::arch::fence_view_async_tmem_load();
|
||||||
|
|
||||||
pipeline_o.consumer_release(pipeline_o_consumer_state);
|
pipeline_o.consumer_release(pipeline_o_consumer_state);
|
||||||
|
|||||||
@ -50,19 +50,13 @@ namespace cutlass::fmha::device {
|
|||||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<
|
template<class Element, class ElementAccumulator, class TileShape, class Mask>
|
||||||
class ProblemShape,
|
|
||||||
class Element,
|
|
||||||
class ElementAccumulator,
|
|
||||||
class TileShape,
|
|
||||||
class Mask
|
|
||||||
>
|
|
||||||
class Sm100FmhaBwd {
|
class Sm100FmhaBwd {
|
||||||
public:
|
public:
|
||||||
/// Argument structure: User API
|
/// Argument structure: User API
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
// Q K D HB
|
// Q K D HB
|
||||||
ProblemShape problem_shape;
|
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
|
||||||
|
|
||||||
const Element* ptr_Q;
|
const Element* ptr_Q;
|
||||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
|
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
|
||||||
@ -92,16 +86,14 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
using OperationSumOdO = cutlass::fmha::device::FMHA<
|
using OperationSumOdO = cutlass::fmha::device::FMHA<
|
||||||
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
|
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
|
||||||
>;
|
>;
|
||||||
using OperationConvert = cutlass::fmha::device::FMHA<
|
using OperationConvert = cutlass::fmha::device::FMHA<
|
||||||
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
|
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
|
||||||
>;
|
>;
|
||||||
|
|
||||||
using Operation = cutlass::fmha::device::FMHA<
|
using Operation = cutlass::fmha::device::FMHA<
|
||||||
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
|
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
|
||||||
ProblemShape, Element, ElementAccumulator, TileShape, Mask
|
|
||||||
>
|
|
||||||
>;
|
>;
|
||||||
using Kernel = typename Operation::Kernel;
|
using Kernel = typename Operation::Kernel;
|
||||||
|
|
||||||
@ -121,15 +113,15 @@ private:
|
|||||||
ElementAccumulator* sum_odo = nullptr,
|
ElementAccumulator* sum_odo = nullptr,
|
||||||
ElementAccumulator* scaled_lse = nullptr) {
|
ElementAccumulator* scaled_lse = nullptr) {
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
auto [Q_, K, D, HB] = args.problem_shape;
|
auto [Q, K, D, HB] = args.problem_size;
|
||||||
auto [H, B] = HB;
|
auto [H, B] = HB;
|
||||||
D = cutlass::round_up(D, 8); // Alignment
|
D = cutlass::round_up(D, 8); // Alignment
|
||||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
Q = cutlass::round_up(Q, 8); // Alignment
|
||||||
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
|
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
|
||||||
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
|
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
|
||||||
auto log2_e = log2f(expf(1.0f));
|
auto log2_e = log2f(expf(1.0f));
|
||||||
return typename OperationSumOdO::Arguments {
|
return typename OperationSumOdO::Arguments {
|
||||||
args.problem_shape,
|
args.problem_size,
|
||||||
args.ptr_O, args.stride_O,
|
args.ptr_O, args.stride_O,
|
||||||
args.ptr_dO, args.stride_dO,
|
args.ptr_dO, args.stride_dO,
|
||||||
sum_odo, stride_sum_OdO,
|
sum_odo, stride_sum_OdO,
|
||||||
@ -141,13 +133,13 @@ private:
|
|||||||
|
|
||||||
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
|
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
auto [Q_, K, D, HB] = args.problem_shape;
|
auto [Q, K, D, HB] = args.problem_size;
|
||||||
auto [H, B] = HB;
|
auto [H, B] = HB;
|
||||||
D = cutlass::round_up(D, 8); // Alignment
|
D = cutlass::round_up(D, 8); // Alignment
|
||||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
Q = cutlass::round_up(Q, 8); // Alignment
|
||||||
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||||
return typename OperationConvert::Arguments {
|
return typename OperationConvert::Arguments {
|
||||||
args.problem_shape,
|
args.problem_size,
|
||||||
src, stride_src_dQ,
|
src, stride_src_dQ,
|
||||||
nullptr, stride_src_dQ,
|
nullptr, stride_src_dQ,
|
||||||
nullptr, stride_src_dQ,
|
nullptr, stride_src_dQ,
|
||||||
@ -164,7 +156,7 @@ private:
|
|||||||
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
|
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
|
||||||
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
|
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
|
||||||
return typename Operation::Arguments{
|
return typename Operation::Arguments{
|
||||||
args.problem_shape,
|
args.problem_size,
|
||||||
{ args.ptr_Q, args.stride_Q,
|
{ args.ptr_Q, args.stride_Q,
|
||||||
args.ptr_K, args.stride_K,
|
args.ptr_K, args.stride_K,
|
||||||
args.ptr_V, args.stride_V,
|
args.ptr_V, args.stride_V,
|
||||||
@ -207,10 +199,10 @@ public:
|
|||||||
/// Gets the workspace size
|
/// Gets the workspace size
|
||||||
static size_t
|
static size_t
|
||||||
get_workspace_size(Arguments const& args) {
|
get_workspace_size(Arguments const& args) {
|
||||||
auto [Q_, K, D, HB] = args.problem_shape;
|
auto [Q, K, D, HB] = args.problem_size;
|
||||||
auto [H, B] = HB;
|
auto [H, B] = HB;
|
||||||
D = cutlass::round_up(D, 8); // Alignment
|
D = cutlass::round_up(D, 8); // Alignment
|
||||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
Q = cutlass::round_up(Q, 8); // Alignment
|
||||||
size_t workspace_bytes = 0;
|
size_t workspace_bytes = 0;
|
||||||
// OdO vector
|
// OdO vector
|
||||||
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
|
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
|
||||||
@ -227,10 +219,10 @@ public:
|
|||||||
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
|
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
|
||||||
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
|
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
|
||||||
|
|
||||||
auto [Q_, K, D, HB] = args.problem_shape;
|
auto [Q, K, D, HB] = args.problem_size;
|
||||||
auto [H, B] = HB;
|
auto [H, B] = HB;
|
||||||
D = cutlass::round_up(D, 8); // Alignment
|
D = cutlass::round_up(D, 8); // Alignment
|
||||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
Q = cutlass::round_up(Q, 8); // Alignment
|
||||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
|
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
|
||||||
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
|
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
|
||||||
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
|
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
|
||||||
@ -256,10 +248,10 @@ public:
|
|||||||
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
|
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
|
||||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||||
|
|
||||||
auto [Q_, K, D, HB] = args.problem_shape;
|
auto [Q, K, D, HB] = args.problem_size;
|
||||||
auto [H, B] = HB;
|
auto [H, B] = HB;
|
||||||
D = cutlass::round_up(D, 8); // Alignment
|
D = cutlass::round_up(D, 8); // Alignment
|
||||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
Q = cutlass::round_up(Q, 8); // Alignment
|
||||||
char* workspace_chr = reinterpret_cast<char*>(workspace);
|
char* workspace_chr = reinterpret_cast<char*>(workspace);
|
||||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||||
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
|
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
|
||||||
|
|||||||
@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {
|
|||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template<class ProblemShape, class Element, class ElementAcc>
|
template<class Element, class ElementAcc>
|
||||||
struct FmhaKernelBwdConvert {
|
struct FmhaKernelBwdConvert {
|
||||||
|
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
ProblemShape problem_shape;
|
tuple<int, int, int, tuple<int, int>> problem_size;
|
||||||
|
|
||||||
const ElementAcc* ptr_src_dQ;
|
const ElementAcc* ptr_src_dQ;
|
||||||
tuple<int, _1, tuple<int, int>> stride_src_dQ;
|
tuple<int, _1, tuple<int, int>> stride_src_dQ;
|
||||||
@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert {
|
|||||||
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
|
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
|
||||||
|
|
||||||
static bool can_implement(Arguments const& args) {
|
static bool can_implement(Arguments const& args) {
|
||||||
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
|
return get<2>(args.problem_size) % kElementsPerLoad == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static dim3 get_grid_shape(Params const& params) {
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
|
dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq));
|
||||||
return grid;
|
return grid;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,25 +102,18 @@ struct FmhaKernelBwdConvert {
|
|||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class StrideSrc, class StrideDest, class Count>
|
template<class StrideSrc, class StrideDest>
|
||||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) {
|
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
|
||||||
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
|
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
|
||||||
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
|
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
|
||||||
|
|
||||||
int seqlen = count;
|
|
||||||
if constexpr (is_variable_length_v<decltype(count)>) {
|
|
||||||
int offset = count.cumulative_length[blockIdx.y];
|
|
||||||
ptr_dest_bh += offset * get<0>(stride_dest);
|
|
||||||
seqlen = count.cumulative_length[blockIdx.y + 1] - offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
|
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
|
||||||
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
|
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
|
||||||
if (idx_s >= seqlen) continue;
|
if (idx_s >= count) continue;
|
||||||
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
|
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
|
||||||
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
|
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
|
||||||
|
|
||||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
|
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||||
ElementAcc value_src[kElementsPerLoad];
|
ElementAcc value_src[kElementsPerLoad];
|
||||||
Element value_dest[kElementsPerLoad];
|
Element value_dest[kElementsPerLoad];
|
||||||
|
|
||||||
@ -139,13 +132,13 @@ struct FmhaKernelBwdConvert {
|
|||||||
|
|
||||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||||
if (params.ptr_src_dQ != nullptr) {
|
if (params.ptr_src_dQ != nullptr) {
|
||||||
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape));
|
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size));
|
||||||
}
|
}
|
||||||
if (params.ptr_src_dK != nullptr) {
|
if (params.ptr_src_dK != nullptr) {
|
||||||
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape));
|
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size));
|
||||||
}
|
}
|
||||||
if (params.ptr_src_dV != nullptr) {
|
if (params.ptr_src_dV != nullptr) {
|
||||||
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape));
|
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel {
|
|||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template<class ProblemShape, class Element, class ElementAcc>
|
template<class Element, class ElementAcc>
|
||||||
struct FmhaKernelBwdSumOdO {
|
struct FmhaKernelBwdSumOdO {
|
||||||
|
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
ProblemShape problem_shape;
|
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
|
||||||
|
|
||||||
const Element* ptr_O;
|
const Element* ptr_O;
|
||||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
|
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
|
||||||
@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO {
|
|||||||
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
|
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
|
||||||
|
|
||||||
static bool can_implement(Arguments const& args) {
|
static bool can_implement(Arguments const& args) {
|
||||||
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
|
return get<2>(args.problem_size) % kElementsPerLoad == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static dim3 get_grid_shape(Params const& params) {
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape));
|
dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size));
|
||||||
return grid;
|
return grid;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,20 +110,10 @@ struct FmhaKernelBwdSumOdO {
|
|||||||
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
|
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
|
||||||
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
|
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
|
||||||
|
|
||||||
auto problem_q = get<0>(params.problem_shape);
|
|
||||||
int seqlen_q = problem_q;
|
|
||||||
if constexpr (is_variable_length_v<decltype(problem_q)>) {
|
|
||||||
int offset = problem_q.cumulative_length[blockIdx.z];
|
|
||||||
ptr_O_bh += offset * get<0>(params.stride_O);
|
|
||||||
ptr_dO_bh += offset * get<0>(params.stride_dO);
|
|
||||||
ptr_lse_bh += offset * get<0>(params.stride_lse);
|
|
||||||
seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
|
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
|
||||||
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
|
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
|
||||||
if (idx_q >= seqlen_q) continue;
|
if (idx_q >= get<0>(params.problem_size)) continue;
|
||||||
ElementAcc acc = 0;
|
ElementAcc acc = 0;
|
||||||
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
|
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
|
||||||
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
|
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
|
||||||
@ -131,7 +121,7 @@ struct FmhaKernelBwdSumOdO {
|
|||||||
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
|
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
|
||||||
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
|
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
|
||||||
|
|
||||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
|
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||||
Element value_O[kElementsPerLoad];
|
Element value_O[kElementsPerLoad];
|
||||||
Element value_dO[kElementsPerLoad];
|
Element value_dO[kElementsPerLoad];
|
||||||
|
|
||||||
|
|||||||
@ -82,4 +82,4 @@ struct Option {
|
|||||||
using option_value = Value;
|
using option_value = Value;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace cutlass::fmha::kernel
|
} // namespace cutlass::fmha::kernel
|
||||||
|
|||||||
@ -90,8 +90,8 @@ struct PersistentTileScheduler {
|
|||||||
struct Params {
|
struct Params {
|
||||||
int num_blocks;
|
int num_blocks;
|
||||||
FastDivmod divmod_m_block;
|
FastDivmod divmod_m_block;
|
||||||
FastDivmod divmod_b;
|
|
||||||
FastDivmod divmod_h;
|
FastDivmod divmod_h;
|
||||||
|
FastDivmod divmod_b;
|
||||||
|
|
||||||
KernelHardwareInfo hw_info;
|
KernelHardwareInfo hw_info;
|
||||||
};
|
};
|
||||||
@ -146,7 +146,7 @@ struct PersistentTileScheduler {
|
|||||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||||
params.divmod_b(block_decode, bidb, block_decode);
|
params.divmod_b(block_decode, bidb, block_decode);
|
||||||
params.divmod_h(block_decode, bidh, block_decode);
|
params.divmod_h(block_decode, bidh, block_decode);
|
||||||
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
|
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
|
||||||
}
|
}
|
||||||
|
|
||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
|
|||||||
@ -43,8 +43,6 @@
|
|||||||
|
|
||||||
#include "collective/fmha_common.hpp"
|
#include "collective/fmha_common.hpp"
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
namespace cutlass::fmha::kernel {
|
namespace cutlass::fmha::kernel {
|
||||||
|
|
||||||
using namespace cutlass::fmha::collective;
|
using namespace cutlass::fmha::collective;
|
||||||
@ -52,7 +50,6 @@ using namespace cutlass::fmha::collective;
|
|||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template<
|
template<
|
||||||
class ProblemShape,
|
|
||||||
class Element,
|
class Element,
|
||||||
class ElementAcc,
|
class ElementAcc,
|
||||||
class TileShape,
|
class TileShape,
|
||||||
@ -121,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
|
|
||||||
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
|
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
|
||||||
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
|
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
|
||||||
|
|
||||||
// compute S
|
// compute S
|
||||||
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
|
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||||
@ -277,6 +274,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
|
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
|
||||||
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
|
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
|
||||||
|
|
||||||
|
using ProblemShape = Shape<int, int, int, Shape<int, int>>; // Q K D (H B), eventuall D = (D_QK, D_VO)
|
||||||
using TensorStride = TensorStrideContiguousK; // S D (H B)
|
using TensorStride = TensorStrideContiguousK; // S D (H B)
|
||||||
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
|
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
|
||||||
|
|
||||||
@ -362,16 +360,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
|
|
||||||
|
|
||||||
static Params to_underlying_arguments(Arguments const& args, void*) {
|
static Params to_underlying_arguments(Arguments const& args, void*) {
|
||||||
auto [Q_, K_, D, HB] = args.problem_shape;
|
auto [Q, K, D, HB] = args.problem_shape;
|
||||||
int Q = Q_;
|
|
||||||
int K = K_;
|
|
||||||
|
|
||||||
if constexpr (is_variable_length_v<decltype(Q_)>) {
|
|
||||||
Q = Q_.total_length;
|
|
||||||
}
|
|
||||||
if constexpr (is_variable_length_v<decltype(K_)>) {
|
|
||||||
K = K_.total_length;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto params_kq = CollectiveMmaKQ::to_underlying_arguments(
|
auto params_kq = CollectiveMmaKQ::to_underlying_arguments(
|
||||||
make_shape(K, Q, D, HB),
|
make_shape(K, Q, D, HB),
|
||||||
@ -389,10 +378,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
|
|
||||||
TMA_DQ tma_red_dq = make_tma_copy(
|
TMA_DQ tma_red_dq = make_tma_copy(
|
||||||
SM90_TMA_REDUCE_ADD{},
|
SM90_TMA_REDUCE_ADD{},
|
||||||
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
|
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
|
||||||
SmemLayoutDQ{}(_, _, _0{})
|
SmemLayoutDQ{}(_, _, _0{})
|
||||||
);
|
);
|
||||||
|
|
||||||
return Params{
|
return Params{
|
||||||
args.problem_shape,
|
args.problem_shape,
|
||||||
args.mainloop,
|
args.mainloop,
|
||||||
@ -427,11 +416,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<class BlkCoord, class BlkOffset, class ProblemShape_>
|
template<class BlkCoord>
|
||||||
CUTLASS_DEVICE void load(
|
CUTLASS_DEVICE void load(
|
||||||
BlkCoord const& blk_coord,
|
BlkCoord const& blk_coord,
|
||||||
BlkOffset const& blk_offset,
|
ProblemShape const& problem_shape,
|
||||||
ProblemShape_ const& problem_shape,
|
|
||||||
int iter_index,
|
int iter_index,
|
||||||
int iter_count,
|
int iter_count,
|
||||||
MainloopArguments const& mainloop_args,
|
MainloopArguments const& mainloop_args,
|
||||||
@ -452,15 +440,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
|
|
||||||
uint16_t mcast_mask = 0;
|
uint16_t mcast_mask = 0;
|
||||||
|
|
||||||
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
|
auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
|
||||||
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
|
auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
|
||||||
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
|
auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
|
||||||
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
|
auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
|
||||||
|
|
||||||
auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in);
|
|
||||||
auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in);
|
|
||||||
auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in);
|
|
||||||
auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in);
|
|
||||||
|
|
||||||
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
|
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
|
||||||
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
|
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
|
||||||
@ -469,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
|
|
||||||
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
|
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
|
||||||
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
|
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
|
||||||
|
|
||||||
auto tSTgK = cta_mma_kq.partition_A(gK);
|
auto tSTgK = cta_mma_kq.partition_A(gK);
|
||||||
auto tSTgQ = cta_mma_kq.partition_B(gQ);
|
auto tSTgQ = cta_mma_kq.partition_B(gQ);
|
||||||
auto tDPTgV = cta_mma_vdo.partition_A(gV);
|
auto tDPTgV = cta_mma_vdo.partition_A(gV);
|
||||||
@ -494,8 +477,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
|
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
|
||||||
|
|
||||||
// set up lse and sum_odo
|
// set up lse and sum_odo
|
||||||
|
|
||||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
|
||||||
|
|
||||||
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
|
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
|
||||||
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
|
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
|
||||||
@ -512,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// load Q
|
// load Q
|
||||||
if (cute::elect_one_sync()) {
|
if (cute::elect_one_sync()) {
|
||||||
cute::copy(
|
cute::copy(
|
||||||
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
|
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
|
||||||
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
|
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||||
@ -532,14 +515,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
|
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
|
||||||
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
||||||
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
|
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
|
||||||
for (int i = 0; i < 4; i++) {
|
cutlass::arch::cp_async_zfill<16>(
|
||||||
cutlass::arch::cp_async_zfill<4>(
|
shared_tensors.smem_lse.begin() + smem_idx,
|
||||||
shared_tensors.smem_lse.begin() + smem_idx + i,
|
&mLSE(gmem_idx, blk_coord_batch),
|
||||||
&mLSE(gmem_idx + i, blk_coord_batch),
|
gmem_idx < Q
|
||||||
gmem_idx + i < Q
|
);
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||||
++pipeline_load_compute_lse_producer_state;
|
++pipeline_load_compute_lse_producer_state;
|
||||||
|
|
||||||
@ -548,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
|
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
|
||||||
|
|
||||||
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
|
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
|
||||||
|
|
||||||
// load V
|
// load V
|
||||||
if (cute::elect_one_sync()) {
|
if (cute::elect_one_sync()) {
|
||||||
cute::copy(
|
cute::copy(
|
||||||
@ -559,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// load dO
|
// load dO
|
||||||
if (cute::elect_one_sync()) {
|
if (cute::elect_one_sync()) {
|
||||||
cute::copy(
|
cute::copy(
|
||||||
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
|
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
|
||||||
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
|
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||||
@ -575,13 +556,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
|
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
|
||||||
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
||||||
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
|
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
|
||||||
for (int i = 0; i < 4; i++) {
|
cutlass::arch::cp_async<16>(
|
||||||
cutlass::arch::cp_async_zfill<4>(
|
shared_tensors.smem_sum_odo.begin() + smem_idx,
|
||||||
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
|
&mSumOdO(gmem_idx, blk_coord_batch),
|
||||||
&mSumOdO(gmem_idx + i, blk_coord_batch),
|
gmem_idx < Q
|
||||||
gmem_idx + i < Q
|
);
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||||
++pipeline_load_compute_sum_odo_producer_state;
|
++pipeline_load_compute_sum_odo_producer_state;
|
||||||
@ -594,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
|
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
|
||||||
|
|
||||||
// load Q
|
// load Q
|
||||||
if (cute::elect_one_sync()) {
|
if (cute::elect_one_sync()) {
|
||||||
cute::copy(
|
cute::copy(
|
||||||
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
|
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
|
||||||
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
|
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||||
@ -605,26 +584,24 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
++pipeline_load_mma_q_producer_state;
|
++pipeline_load_mma_q_producer_state;
|
||||||
|
|
||||||
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
|
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
|
||||||
|
|
||||||
// load LSE
|
// load LSE
|
||||||
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
|
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
|
||||||
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
||||||
for (int i = 0; i < 4; i++) {
|
cutlass::arch::cp_async<16>(
|
||||||
cutlass::arch::cp_async_zfill<4>(
|
shared_tensors.smem_lse.begin() + smem_idx,
|
||||||
shared_tensors.smem_lse.begin() + smem_idx + i,
|
&mLSE(gmem_idx, blk_coord_batch),
|
||||||
&mLSE(gmem_idx + i, blk_coord_batch),
|
gmem_idx < Q
|
||||||
gmem_idx + i < Q
|
);
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||||
++pipeline_load_compute_lse_producer_state;
|
++pipeline_load_compute_lse_producer_state;
|
||||||
|
|
||||||
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
|
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
|
||||||
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
|
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
|
||||||
|
|
||||||
// load dO
|
// load dO
|
||||||
if (cute::elect_one_sync()) {
|
if (cute::elect_one_sync()) {
|
||||||
cute::copy(
|
cute::copy(
|
||||||
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
|
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
|
||||||
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
|
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||||
@ -635,18 +612,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
++pipeline_load_mma_do_producer_state;
|
++pipeline_load_mma_do_producer_state;
|
||||||
|
|
||||||
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
|
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
|
||||||
|
|
||||||
// load sum_OdO
|
// load sum_OdO
|
||||||
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
|
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
|
||||||
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
||||||
for (int i = 0; i < 4; i++) {
|
cutlass::arch::cp_async_zfill<16>(
|
||||||
cutlass::arch::cp_async_zfill<4>(
|
shared_tensors.smem_sum_odo.begin() + smem_idx,
|
||||||
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
|
&mSumOdO(gmem_idx, blk_coord_batch),
|
||||||
&mSumOdO(gmem_idx + i, blk_coord_batch),
|
gmem_idx < Q
|
||||||
gmem_idx + i < Q
|
);
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||||
++pipeline_load_compute_sum_odo_producer_state;
|
++pipeline_load_compute_sum_odo_producer_state;
|
||||||
|
|
||||||
@ -656,31 +631,31 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<class BlkCoord, class ProblemShape_>
|
template<class BlkCoord>
|
||||||
CUTLASS_DEVICE void mma(
|
CUTLASS_DEVICE void mma(
|
||||||
BlkCoord const& blk_coord,
|
BlkCoord const& blk_coord,
|
||||||
ProblemShape_ const& problem_shape,
|
ProblemShape const& problem_shape,
|
||||||
int iter_index,
|
int iter_index,
|
||||||
int iter_count,
|
int iter_count,
|
||||||
MainloopArguments const& mainloop_args,
|
MainloopArguments const& mainloop_args,
|
||||||
TensorStorage& shared_tensors,
|
TensorStorage& shared_tensors,
|
||||||
PipelineLoadMmaQ& pipeline_load_mma_q,
|
PipelineLoadMmaQ& pipeline_load_mma_q,
|
||||||
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
|
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
|
||||||
PipelineLoadMmaDO& pipeline_load_mma_do,
|
PipelineLoadMmaDO& pipeline_load_mma_do,
|
||||||
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
|
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
|
||||||
PipelineMmaComputeS& pipeline_mma_compute_s,
|
PipelineMmaComputeS& pipeline_mma_compute_s,
|
||||||
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
|
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
|
||||||
PipelineMmaComputeDP& pipeline_mma_compute_dp,
|
PipelineMmaComputeDP& pipeline_mma_compute_dp,
|
||||||
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
|
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
|
||||||
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
|
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
|
||||||
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
|
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
|
||||||
PipelineComputeMmaP& pipeline_compute_mma_p,
|
PipelineComputeMmaP& pipeline_compute_mma_p,
|
||||||
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
|
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
|
||||||
PipelineComputeMmaDS& pipeline_compute_mma_ds,
|
PipelineComputeMmaDS& pipeline_compute_mma_ds,
|
||||||
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
|
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
|
||||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
|
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
|
||||||
|
|
||||||
auto [Q, K, D, HB] = problem_shape;
|
auto [Q, K, D, HB] = problem_shape;
|
||||||
|
|
||||||
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
|
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
|
||||||
@ -710,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
|
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
|
||||||
tDVrP.data() = TmemAllocation::kP;
|
tDVrP.data() = TmemAllocation::kP;
|
||||||
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
|
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
|
||||||
|
|
||||||
TiledMmaKQ tiled_mma_kq;
|
TiledMmaKQ tiled_mma_kq;
|
||||||
TiledMmaVDO tiled_mma_vdo;
|
TiledMmaVDO tiled_mma_vdo;
|
||||||
TiledMmaDSK tiled_mma_dsk;
|
TiledMmaDSK tiled_mma_dsk;
|
||||||
@ -948,8 +923,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
TensorC const& coord,
|
TensorC const& coord,
|
||||||
TensorShape const& tensor_shape) {
|
TensorShape const& tensor_shape) {
|
||||||
|
|
||||||
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
|
|
||||||
|
|
||||||
auto copy_op = make_cotiled_copy(
|
auto copy_op = make_cotiled_copy(
|
||||||
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
|
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
|
||||||
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
|
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
|
||||||
@ -957,91 +930,42 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
);
|
);
|
||||||
auto thr_copy = copy_op.get_slice(_0{});
|
auto thr_copy = copy_op.get_slice(_0{});
|
||||||
|
|
||||||
Tensor tCg = thr_copy.partition_D(gmem);
|
auto tCg = thr_copy.partition_D(gmem);
|
||||||
Tensor tCr = thr_copy.partition_S(quantize(regs));
|
auto tCr = thr_copy.partition_S(quantize(regs));
|
||||||
Tensor tPc = thr_copy.partition_D(preds);
|
auto tCc = thr_copy.partition_D(coord);
|
||||||
|
|
||||||
copy_if(copy_op, tPc, tCr, tCg);
|
constexpr int R = decltype(tCr.layout())::rank;
|
||||||
}
|
auto tCg_v = group_modes<1, R>(tCg);
|
||||||
|
auto tCr_v = group_modes<1, R>(tCr);
|
||||||
|
auto tCc_v = group_modes<1, R>(tCc);
|
||||||
|
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
|
||||||
|
|
||||||
|
for (int i = 0; i < size(tCp_v); ++i) {
|
||||||
template<class BlkCoord, class BlkOffset, class ProblemShape_>
|
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
|
||||||
CUTLASS_DEVICE void epilogue_clear(
|
|
||||||
BlkCoord const& blk_coord,
|
|
||||||
BlkOffset const& blk_offset,
|
|
||||||
ProblemShape_ const& problem_shape,
|
|
||||||
MainloopArguments const& mainloop_args,
|
|
||||||
EpilogueArguments const& epilogue_args) {
|
|
||||||
|
|
||||||
auto [Q, K, D, HB] = problem_shape;
|
|
||||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
|
||||||
|
|
||||||
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
|
|
||||||
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
|
|
||||||
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
|
||||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
|
||||||
|
|
||||||
Tensor cDK = domain_offset(
|
|
||||||
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
|
|
||||||
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
|
|
||||||
);
|
|
||||||
|
|
||||||
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
|
|
||||||
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
|
|
||||||
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
|
||||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
|
||||||
|
|
||||||
Tensor cDV = domain_offset(
|
|
||||||
make_coord(blk_coord_k * TileShapeK{}, _0{}),
|
|
||||||
make_identity_tensor(take<0,2>(TileShapePDO{}))
|
|
||||||
);
|
|
||||||
|
|
||||||
if (threadIdx.x >= 256) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tiled_copy = make_cotiled_copy(
|
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
|
||||||
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
|
|
||||||
make_ordered_layout(make_shape(_256{}, Int<sizeof(uint128_t) / sizeof(Element)>{}), Step<_1, _0>{}),
|
|
||||||
make_ordered_layout(make_shape(TileShapeK{}, TileShapeDQK{}), Step<_1, _0>{}));
|
|
||||||
|
|
||||||
auto thr_copy = tiled_copy.get_slice(threadIdx.x);
|
|
||||||
auto tCgDK = thr_copy.partition_D(gDK);
|
|
||||||
auto tCcDK = thr_copy.partition_S(cDK);
|
|
||||||
auto tCrDK = make_tensor<Element>(shape(tCcDK));
|
|
||||||
|
|
||||||
clear(tCrDK);
|
|
||||||
store(tCgDK, tCrDK, tCcDK, select<1,2>(problem_shape));
|
|
||||||
|
|
||||||
auto tCgDV = thr_copy.partition_D(gDV);
|
|
||||||
auto tCcDV = thr_copy.partition_S(cDV);
|
|
||||||
auto tCrDV = make_tensor<Element>(shape(tCcDV));
|
|
||||||
|
|
||||||
clear(tCrDV);
|
|
||||||
store(tCgDV, tCrDV, tCcDV, select<1,2>(problem_shape));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<class BlkCoord, class BlkOffset, class ProblemShape_>
|
template<class BlkCoord>
|
||||||
CUTLASS_DEVICE void epilogue(
|
CUTLASS_DEVICE void epilogue(
|
||||||
BlkCoord const& blk_coord,
|
BlkCoord const& blk_coord,
|
||||||
BlkOffset const& blk_offset,
|
ProblemShape const& problem_shape,
|
||||||
ProblemShape_ const& problem_shape,
|
|
||||||
MainloopArguments const& mainloop_args,
|
MainloopArguments const& mainloop_args,
|
||||||
EpilogueArguments const& epilogue_args,
|
EpilogueArguments const& epilogue_args,
|
||||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
||||||
|
|
||||||
auto [Q, K, D, HB] = problem_shape;
|
auto [Q, K, D, HB] = problem_shape;
|
||||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
|
||||||
|
|
||||||
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
|
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
|
||||||
|
|
||||||
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
|
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
|
||||||
tDKtDK.data() = TmemAllocation::kDK;
|
tDKtDK.data() = TmemAllocation::kDK;
|
||||||
|
|
||||||
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
|
auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
|
||||||
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
|
|
||||||
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
||||||
|
|
||||||
@ -1076,13 +1000,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
|
auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
|
||||||
tDVtDV.data() = TmemAllocation::kDV;
|
tDVtDV.data() = TmemAllocation::kDV;
|
||||||
|
|
||||||
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
|
auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
|
||||||
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
|
|
||||||
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
||||||
|
|
||||||
Tensor cDV = domain_offset(
|
Tensor cDV = domain_offset(
|
||||||
make_coord(blk_coord_k * TileShapeK{}, _0{}),
|
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
|
||||||
make_identity_tensor(take<0,2>(TileShapePDO{}))
|
make_identity_tensor(take<0,2>(TileShapePDO{}))
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -1126,11 +1049,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<class BlkCoord, class BlkOffset, class ProblemShape_>
|
template<class BlkCoord>
|
||||||
CUTLASS_DEVICE void compute(
|
CUTLASS_DEVICE void compute(
|
||||||
BlkCoord const& blk_coord,
|
BlkCoord const& blk_coord,
|
||||||
BlkOffset const& blk_offset,
|
ProblemShape const& problem_shape,
|
||||||
ProblemShape_ const& problem_shape,
|
|
||||||
int iter_index,
|
int iter_index,
|
||||||
int iter_count,
|
int iter_count,
|
||||||
MainloopArguments const& mainloop_args,
|
MainloopArguments const& mainloop_args,
|
||||||
@ -1151,7 +1073,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
||||||
|
|
||||||
|
|
||||||
auto [Q, K, D, HB] = problem_shape;
|
auto [Q, K, D, HB] = problem_shape;
|
||||||
|
|
||||||
// in tmem, S & P overlap
|
// in tmem, S & P overlap
|
||||||
@ -1192,7 +1114,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
|
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
|
||||||
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
|
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
|
||||||
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
|
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
|
||||||
|
|
||||||
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
|
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
|
||||||
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
|
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
|
||||||
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
|
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
|
||||||
@ -1214,9 +1136,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));
|
auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));
|
||||||
auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST));
|
auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST));
|
||||||
|
|
||||||
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
|
|
||||||
int last_iter = iter_count - 1 + iter_index;
|
|
||||||
|
|
||||||
CUTLASS_PRAGMA_NO_UNROLL
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
while (iter_count > 0) {
|
while (iter_count > 0) {
|
||||||
// wait for S and P
|
// wait for S and P
|
||||||
@ -1233,28 +1152,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
fn(cute::false_type{});
|
fn(cute::false_type{});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool leading_causal_masking = false;
|
dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
|
||||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
|
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {
|
||||||
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
|
|
||||||
}
|
|
||||||
bool trailing_residual_masking = false;
|
|
||||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
|
|
||||||
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
|
|
||||||
|
|
||||||
// compute P = softmax(S, LSE)
|
// compute P = softmax(S, LSE)
|
||||||
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
|
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
|
||||||
|
|
||||||
if constexpr (decltype(is_masked_tile)::value) {
|
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
|
||||||
Mask{}.apply_mask(tTR_rST, [&](int i) {
|
Mask{}.apply_mask(tTR_rST, [&](int i) {
|
||||||
auto c_transpose = tTR_cST(i);
|
auto c_transpose = tTR_cST(i);
|
||||||
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
|
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
|
||||||
}, problem_shape);
|
}, problem_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
|
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
|
||||||
float2 softmax_scale_log2_e;
|
float2 softmax_scale_log2_e;
|
||||||
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
|
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
|
||||||
@ -1273,16 +1184,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
tTR_rST(i) = ::exp2f(out.x);
|
tTR_rST(i) = ::exp2f(out.x);
|
||||||
tTR_rST(i+1) = ::exp2f(out.y);
|
tTR_rST(i+1) = ::exp2f(out.y);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tRT_rST = quantize(tTR_rST);
|
auto tRT_rST = quantize(tTR_rST);
|
||||||
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
|
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
|
||||||
|
|
||||||
cutlass::arch::fence_view_async_tmem_load();
|
cutlass::arch::fence_view_async_tmem_load();
|
||||||
cutlass::arch::NamedBarrier(
|
cutlass::arch::NamedBarrier(
|
||||||
kNumComputeWarps * NumThreadsPerWarp,
|
kNumComputeWarps * NumThreadsPerWarp,
|
||||||
cutlass::arch::ReservedNamedBarriers::TransformBarrier
|
cutlass::arch::ReservedNamedBarriers::TransformBarrier
|
||||||
).arrive_and_wait();
|
).arrive_and_wait();
|
||||||
|
|
||||||
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
|
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -1364,15 +1275,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
|
|
||||||
epilogue(
|
epilogue(
|
||||||
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
|
blk_coord, problem_shape, mainloop_args, epilogue_args,
|
||||||
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
|
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class BlkCoord, class ProblemShape_>
|
template<class BlkCoord>
|
||||||
CUTLASS_DEVICE void reduce(
|
CUTLASS_DEVICE void reduce(
|
||||||
BlkCoord const& blk_coord,
|
BlkCoord const& blk_coord,
|
||||||
ProblemShape_ const& problem_shape,
|
ProblemShape const& problem_shape,
|
||||||
int iter_index,
|
int iter_index,
|
||||||
int iter_count,
|
int iter_count,
|
||||||
MainloopArguments const& mainloop_args,
|
MainloopArguments const& mainloop_args,
|
||||||
@ -1382,12 +1293,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
|
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
|
||||||
PipelineReduceTmaStore& pipeline_reduce_tma_store,
|
PipelineReduceTmaStore& pipeline_reduce_tma_store,
|
||||||
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
|
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
|
||||||
|
|
||||||
using X = Underscore;
|
using X = Underscore;
|
||||||
|
|
||||||
auto [Q, K, D, HB] = problem_shape;
|
auto [Q, K, D, HB] = problem_shape;
|
||||||
|
|
||||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
|
||||||
|
|
||||||
// must match TileShapeDQ
|
// must match TileShapeDQ
|
||||||
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
|
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
|
||||||
@ -1396,7 +1307,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
tDQtDQ.data() = TmemAllocation::kDQ;
|
tDQtDQ.data() = TmemAllocation::kDQ;
|
||||||
|
|
||||||
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
|
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
|
||||||
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
|
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||||
(_, _, _, _0{}, blk_coord_batch);
|
(_, _, _, _0{}, blk_coord_batch);
|
||||||
|
|
||||||
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
|
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
|
||||||
@ -1465,7 +1376,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
iter_index += 1;
|
iter_index += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
|
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
|
||||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||||
@ -1650,7 +1561,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
|
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
|
||||||
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
|
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
|
||||||
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
|
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
|
||||||
|
|
||||||
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
|
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
|
||||||
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
|
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
|
||||||
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
|
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
|
||||||
@ -1665,45 +1576,27 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
|
|
||||||
pipeline_init_wait(size(ClusterShape{}));
|
pipeline_init_wait(size(ClusterShape{}));
|
||||||
|
|
||||||
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z));
|
||||||
auto [problem_shape, blk_offset] = apply_variable_length_offset(
|
auto problem_shape = params.problem_shape;
|
||||||
params.problem_shape,
|
|
||||||
blk_coord
|
|
||||||
);
|
|
||||||
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
|
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
|
||||||
int iter_start = 0;
|
int iter_start = 0;
|
||||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
|
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
|
||||||
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
|
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
|
||||||
}
|
}
|
||||||
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
iter_count -= iter_start;
|
iter_count -= iter_start;
|
||||||
|
|
||||||
if (iter_count <= 0) {
|
|
||||||
epilogue_clear(
|
|
||||||
blk_coord,
|
|
||||||
blk_offset,
|
|
||||||
problem_shape,
|
|
||||||
params.mainloop,
|
|
||||||
params.epilogue
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (role == WarpRole::Load) {
|
if (role == WarpRole::Load) {
|
||||||
warpgroup_reg_set<RegisterAllocation::kLoad>();
|
warpgroup_reg_set<RegisterAllocation::kLoad>();
|
||||||
|
|
||||||
load(
|
load(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
blk_offset,
|
|
||||||
problem_shape,
|
problem_shape,
|
||||||
iter_start,
|
iter_start,
|
||||||
iter_count,
|
iter_count,
|
||||||
params.mainloop,
|
params.mainloop,
|
||||||
params.mainloop_params,
|
params.mainloop_params,
|
||||||
shared_storage.tensors,
|
shared_storage.tensors,
|
||||||
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
|
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
|
||||||
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
|
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
|
||||||
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
|
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
|
||||||
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
|
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
|
||||||
@ -1715,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
|
|
||||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
|
||||||
mma(
|
mma(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
problem_shape,
|
problem_shape,
|
||||||
@ -1723,7 +1616,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
iter_count,
|
iter_count,
|
||||||
params.mainloop,
|
params.mainloop,
|
||||||
shared_storage.tensors,
|
shared_storage.tensors,
|
||||||
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
|
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
|
||||||
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
|
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
|
||||||
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
|
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
|
||||||
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
|
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
|
||||||
@ -1736,10 +1629,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
else if (role == WarpRole::Compute) {
|
else if (role == WarpRole::Compute) {
|
||||||
warpgroup_reg_set<RegisterAllocation::kCompute>();
|
warpgroup_reg_set<RegisterAllocation::kCompute>();
|
||||||
|
|
||||||
compute(
|
compute(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
blk_offset,
|
|
||||||
problem_shape,
|
problem_shape,
|
||||||
iter_start,
|
iter_start,
|
||||||
iter_count,
|
iter_count,
|
||||||
@ -1768,7 +1660,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
else if (role == WarpRole::Reduce) {
|
else if (role == WarpRole::Reduce) {
|
||||||
warpgroup_reg_set<RegisterAllocation::kReduce>();
|
warpgroup_reg_set<RegisterAllocation::kReduce>();
|
||||||
|
|
||||||
reduce(
|
reduce(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
problem_shape,
|
problem_shape,
|
||||||
@ -1785,9 +1677,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
warpgroup_reg_set<RegisterAllocation::kEmpty>();
|
warpgroup_reg_set<RegisterAllocation::kEmpty>();
|
||||||
|
|
||||||
/* no-op */
|
/* no-op */
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
|||||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
|
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
|
||||||
|
|
||||||
CollectiveMainloop mainloop;
|
CollectiveMainloop mainloop;
|
||||||
CollectiveEpilogue epilogue{params.epilogue};
|
CollectiveEpilogue epilogue;
|
||||||
|
|
||||||
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
|
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
|
||||||
warpgroup_reg_set<NumRegsSoftmax>();
|
warpgroup_reg_set<NumRegsSoftmax>();
|
||||||
@ -407,8 +407,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
|||||||
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
||||||
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
||||||
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
||||||
pipeline_corr_epi, pipeline_corr_epi_producer_state,
|
pipeline_corr_epi, pipeline_corr_epi_producer_state
|
||||||
epilogue
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
|
|||||||
ElementAcc sum_lse = 0;
|
ElementAcc sum_lse = 0;
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
|
||||||
}
|
}
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {
|
|||||||
|
|
||||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||||
|
|
||||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
|
||||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||||
gLSE(0) = global_lse;
|
gLSE(0) = global_lse;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -44,10 +44,10 @@ template<
|
|||||||
class Fusion
|
class Fusion
|
||||||
>
|
>
|
||||||
void __global__ fmha_bwd_reference_dQ_kernel(
|
void __global__ fmha_bwd_reference_dQ_kernel(
|
||||||
ProblemShape problem_shape_in,
|
ProblemShape problem_shape,
|
||||||
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */
|
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
|
||||||
Fusion fusion) {
|
Fusion fusion) {
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
@ -58,28 +58,15 @@ void __global__ fmha_bwd_reference_dQ_kernel(
|
|||||||
extern __shared__ char mS_mem[];
|
extern __shared__ char mS_mem[];
|
||||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||||
|
|
||||||
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
|
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||||
|
|
||||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
|
||||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
|
||||||
problem_shape_in,
|
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
|
||||||
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
|
|
||||||
);
|
|
||||||
// problem_shape = problem_shape_in;
|
|
||||||
// offset = repeat_like(problem_shape_in, _0{});
|
|
||||||
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
|
|
||||||
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
|
|
||||||
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
|
|
||||||
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
|
|
||||||
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
|
|
||||||
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
|
|
||||||
auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in);
|
|
||||||
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
|
|
||||||
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
|
|
||||||
ElementAccumulator acc_qk = 0;
|
ElementAccumulator acc_qk = 0;
|
||||||
ElementAccumulator acc_dov = 0;
|
ElementAccumulator acc_dov = 0;
|
||||||
ElementAccumulator acc_doo = 0;
|
ElementAccumulator acc_doo = 0;
|
||||||
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
|
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||||
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||||
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||||
@ -96,9 +83,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
|
||||||
ElementAccumulator acc = 0;
|
ElementAccumulator acc = 0;
|
||||||
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
|
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||||
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
|
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
|
||||||
}
|
}
|
||||||
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
|
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
|
||||||
@ -117,10 +104,10 @@ template<
|
|||||||
class Fusion
|
class Fusion
|
||||||
>
|
>
|
||||||
void __global__ fmha_bwd_reference_dK_kernel(
|
void __global__ fmha_bwd_reference_dK_kernel(
|
||||||
ProblemShape problem_shape_in,
|
ProblemShape problem_shape,
|
||||||
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
/* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */
|
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
|
||||||
Fusion fusion) {
|
Fusion fusion) {
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
@ -131,28 +118,15 @@ void __global__ fmha_bwd_reference_dK_kernel(
|
|||||||
extern __shared__ char mS_mem[];
|
extern __shared__ char mS_mem[];
|
||||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||||
|
|
||||||
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
|
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||||
|
|
||||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
|
||||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
|
||||||
problem_shape_in,
|
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
|
||||||
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
|
|
||||||
);
|
|
||||||
// problem_shape = problem_shape_in;
|
|
||||||
// offset = repeat_like(problem_shape_in, _0{});
|
|
||||||
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
|
|
||||||
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
|
|
||||||
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
|
|
||||||
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
|
|
||||||
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
|
|
||||||
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
|
|
||||||
auto mDK = domain_offset(select<1,2,3>(offset), mDK_in);
|
|
||||||
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
|
|
||||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
|
|
||||||
ElementAccumulator acc_qk = 0;
|
ElementAccumulator acc_qk = 0;
|
||||||
ElementAccumulator acc_dov = 0;
|
ElementAccumulator acc_dov = 0;
|
||||||
ElementAccumulator acc_doo = 0;
|
ElementAccumulator acc_doo = 0;
|
||||||
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
|
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||||
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||||
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||||
@ -169,9 +143,9 @@ void __global__ fmha_bwd_reference_dK_kernel(
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
|
||||||
ElementAccumulator acc = 0;
|
ElementAccumulator acc = 0;
|
||||||
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
|
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
|
||||||
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
|
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
|
||||||
}
|
}
|
||||||
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
|
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
|
||||||
@ -190,10 +164,10 @@ template<
|
|||||||
class Fusion
|
class Fusion
|
||||||
>
|
>
|
||||||
void __global__ fmha_bwd_reference_dV_kernel(
|
void __global__ fmha_bwd_reference_dV_kernel(
|
||||||
ProblemShape problem_shape_in,
|
ProblemShape problem_shape,
|
||||||
TensorQ mQ_in, TensorK mK_in, TensorV mV_in,
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in,
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
/* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in,
|
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
|
||||||
Fusion fusion) {
|
Fusion fusion) {
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
@ -204,27 +178,14 @@ void __global__ fmha_bwd_reference_dV_kernel(
|
|||||||
extern __shared__ char mS_mem[];
|
extern __shared__ char mS_mem[];
|
||||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||||
|
|
||||||
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<2>(problem_shape_in)));
|
ElementAcc softmax_scale = static_cast<ElementAcc>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||||
|
|
||||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
|
||||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
|
||||||
problem_shape_in,
|
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
|
||||||
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
|
|
||||||
);
|
|
||||||
// problem_shape = problem_shape_in;
|
|
||||||
// offset = repeat_like(problem_shape_in, _0{});
|
|
||||||
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
|
|
||||||
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
|
|
||||||
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
|
|
||||||
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
|
|
||||||
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
|
|
||||||
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
|
|
||||||
auto mDV = domain_offset(select<1,2,3>(offset), mDV_in);
|
|
||||||
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
|
|
||||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
|
|
||||||
ElementAcc acc_qk = 0;
|
ElementAcc acc_qk = 0;
|
||||||
|
|
||||||
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
|
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||||
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
|
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
|
||||||
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
|
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
|
||||||
acc_qk += rQ * rK;
|
acc_qk += rQ * rK;
|
||||||
@ -241,9 +202,9 @@ void __global__ fmha_bwd_reference_dV_kernel(
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
|
||||||
ElementAcc acc = 0;
|
ElementAcc acc = 0;
|
||||||
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
|
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
|
||||||
ElementAcc rS = mS[idx_Q];
|
ElementAcc rS = mS[idx_Q];
|
||||||
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
|
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
|
||||||
acc += rS * rDO;
|
acc += rS * rDO;
|
||||||
|
|||||||
@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
|
|||||||
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
|
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
|
if (threadIdx.x == 0) {
|
||||||
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
|
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -75,8 +75,6 @@ struct DeviceAllocation {
|
|||||||
|
|
||||||
size_t size() const { return size_; }
|
size_t size() const { return size_; }
|
||||||
|
|
||||||
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
|
|
||||||
|
|
||||||
void copy_from_host(const T* ptr, size_t sz) {
|
void copy_from_host(const T* ptr, size_t sz) {
|
||||||
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
|
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
|
||||||
assert(ret == cudaSuccess);
|
assert(ret == cudaSuccess);
|
||||||
|
|||||||
Reference in New Issue
Block a user