Compare commits

5 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | e9a75581fe |  |
|  | ac210faef8 |  |
|  | 15f5468872 |  |
|  | af5519d938 |  |
|  | 415d587ebf |  |

@@ -0,0 +1,712 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Grouped-scale Hopper FP8 GEMM example using the DeepGEMM kernel on the NVIDIA Hopper architecture

    This example demonstrates a grouped, scaled FP8 GEMM using the DeepGEMM warp-specialized kernel
    on the NVIDIA Hopper architecture. The features showcased in this example are as follows:

    1. The NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
    which are more efficient than the Ampere tensor core instructions.

    2. The NVIDIA Hopper architecture includes a new Tensor Memory Accelerator (TMA) unit to transfer large
    blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
    copies between thread blocks in a cluster.

    3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).

    4. This example shows the important scaling fusions used by FP8 GEMM kernels: a per-row (grouped along M)
    scale factor for the A tensor, blocked scale factors along K for the A tensor, and blocked scale factors
    for the B tensor.

    5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
    CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning them we can
    improve performance.

    Examples:

      $ ./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_deepgemm \
          --m=4096 --iterations=1000
*/

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
// #include "cutlass/util/reference/host/tensor_copy.h"
// #include "cutlass/util/reference/host/tensor_compare.h"
// #include "cutlass/util/reference/host/tensor_norm.h"

// Includes from examples directory
#include "helper.h"
// #include "reference/host/gemm_with_groupwise_scaling.h"

#include "deep_gemm/include/deep_gemm/fp8_gemm.cuh"

// using namespace cute;
using namespace deep_gemm;


#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

// Command line options parsing
struct Options {

  bool help;
  int iterations;
  int m, n, k, num_groups;

  Options():
    help(false),
    m(4096),
    n(4096),
    k(4096),
    num_groups(4),
    iterations(10)
  { }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);
    Options defaults;

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
      return;
    }

    cmd.get_cmd_line_argument("m", m, defaults.m);
    cmd.get_cmd_line_argument("num_groups", num_groups, defaults.num_groups);
    cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "67_hopper_fp8_deepgemm\n\n"
        << "  Hopper FP8 DeepGEMM kernel.\n\n"
        << "Options:\n\n"
        << "  --help                      If specified, displays this usage statement\n\n"
        << "  --m=<int>                   Sets the m size\n"
        << "  --num_groups=<int>          Sets the number of groups\n"
        << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";

    return out;
  }

  double gflops(double runtime_s) const
  {
    // Two flops per multiply-add
    uint64_t flop = uint64_t(2) * m * n * k;
    double gflop = double(flop) / double(1.0e9);
    return gflop / runtime_s;
  }
};

/// Result structure
struct Result
{
  double avg_runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  Result(
    double avg_runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess)
  :
    avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
  {}
};

constexpr int cdiv(int a, int b) {
  return (a + b - 1) / b;
}

// #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
  cutlass::TensorView<Element, Layout> view,
  cutlass::Distribution::Kind dist_kind,
  uint64_t seed) {

  if (dist_kind == cutlass::Distribution::Uniform) {

    double scope_max, scope_min;
    int bits_input = cutlass::sizeof_bits<Element>::value;
    int bits_output = cutlass::sizeof_bits<Element>::value;

    if (bits_input == 1) {
      scope_max = 2;
      scope_min = 0;
    } else if (bits_input <= 8) {
      scope_max = 2;
      scope_min = -2;
    } else if (bits_output == 16) {
      scope_max = 5;
      scope_min = -5;
    } else {
      scope_max = 8;
      scope_min = -8;
    }

    cutlass::reference::host::TensorFillRandomUniform(
      view, seed, scope_max, scope_min, 0);
  }
  else if (dist_kind == cutlass::Distribution::AllZeros) {
    cutlass::reference::host::TensorFill(view);
  }
  else if (dist_kind == cutlass::Distribution::Identity) {
    cutlass::reference::host::TensorFillIdentity(view);
  }
  else if (dist_kind == cutlass::Distribution::Gaussian) {
    cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
  }
  else if (dist_kind == cutlass::Distribution::Sequential) {
    cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
  }
  else {
    throw std::runtime_error("Not implemented.");
  }

  return true;
}

/// Helper to initialize a block of device data (scale tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
  cutlass::TensorView<Element, Layout> view,
  cutlass::Distribution::Kind dist_kind,
  uint64_t seed) {

  if (dist_kind == cutlass::Distribution::Uniform) {

    double scope_max, scope_min;

    scope_min = -1;
    scope_max = 1;

    cutlass::reference::host::TensorFillRandomUniform(
      view, seed, scope_max, scope_min, 0);
  }
  else if (dist_kind == cutlass::Distribution::AllZeros) {
    cutlass::reference::host::TensorFill(view);
  }
  else if (dist_kind == cutlass::Distribution::Identity) {
    cutlass::reference::host::TensorFillIdentity(view);
  }
  else if (dist_kind == cutlass::Distribution::Gaussian) {
    cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
  }
  else if (dist_kind == cutlass::Distribution::Sequential) {
    cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
  }
  else {
    throw std::runtime_error("Not implemented.");
  }

  return true;
}


/// TODO: add reference check
bool verify(const Options &options) {
  //
  // Compute reference output
  //
  return true;
}
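
// NOTE (editorial sketch, not part of the original PR): a minimal host-side
// reference for the blockwise-scaled FP8 GEMM that verify() is meant to check
// against. It assumes the layouts used by TestGemm::initialize(): row-major
// A [m, k], row-major B [n, k] (so D = A * B^T), A scales stored column-major
// with logical shape [m, cdiv(k, 128)], and B scales stored row-major with
// shape [cdiv(n, 128), cdiv(k, 128)]. The function name and signature are
// illustrative only.
template <typename ElementA, typename ElementB, typename ElementOut>
void reference_blockwise_gemm(
    int m, int n, int k,
    ElementA const* a, float const* scales_a,  // a: [m, k]; scales_a: [m, k/128] column-major
    ElementB const* b, float const* scales_b,  // b: [n, k]; scales_b: [n/128, k/128] row-major
    ElementOut* d) {                           // d: [m, n]
  int k_blocks = cdiv(k, 128);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int kb = 0; kb < k_blocks; ++kb) {
        // Accumulate one 128-wide K block in FP32, then promote it with the
        // product of the per-row A scale and the per-(128x128)-tile B scale.
        float partial = 0.f;
        int k_end = (kb + 1) * 128 < k ? (kb + 1) * 128 : k;
        for (int kk = kb * 128; kk < k_end; ++kk) {
          partial += float(a[i * k + kk]) * float(b[j * k + kk]);
        }
        acc += partial * scales_a[kb * m + i] * scales_b[(j / 128) * k_blocks + kb];
      }
      d[i * n + j] = ElementOut(acc);
    }
  }
}
// A verify() implementation could call this on the host copies of the tensors and
// compare against Tensor_out with a relative-error threshold.
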
struct TestGemm {
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
using ElementScale = float;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::bfloat16_t;
|
||||
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_lhs;
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_rhs;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::ColumnMajor> Tensor_lhs_scale;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::RowMajor> Tensor_rhs_scale;
|
||||
cutlass::HostTensor<ElementOut, cutlass::layout::RowMajor> Tensor_out;
|
||||
|
||||
|
||||
/// Initialize operands to be used in the GEMM
|
||||
void initialize(
|
||||
const Options &options,
|
||||
uint64_t seed = 2025) {
|
||||
|
||||
Tensor_lhs.resize({options.m, options.k}); //[m, k]
|
||||
Tensor_rhs.resize({options.n, options.k}); //[n, k]
|
||||
Tensor_lhs_scale.resize({options.m, cdiv(options.k, 128)}); // [m, cdiv(k, 128)] column major
|
||||
Tensor_rhs_scale.resize({cdiv(options.n, 128), cdiv(options.k, 128)}); // [cdiv(n, 128), cdiv(k, 128)]
|
||||
Tensor_out.resize({options.m, options.n}); // [m, n]
|
||||
|
||||
initialize_tensor(Tensor_lhs.host_view(), cutlass::Distribution::Uniform, seed + 1);
|
||||
initialize_tensor(Tensor_rhs.host_view(), cutlass::Distribution::Uniform, seed + 2);
|
||||
initialize_scale_tensor(Tensor_lhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 3);
|
||||
initialize_scale_tensor(Tensor_rhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 4);
|
||||
|
||||
Tensor_lhs.sync_device();
|
||||
Tensor_rhs.sync_device();
|
||||
Tensor_lhs_scale.sync_device();
|
||||
Tensor_rhs_scale.sync_device();
|
||||
Tensor_out.sync_device();
|
||||
|
||||
}
|
||||
|
||||
void run(Options &options)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
|
||||
|
||||
initialize(options);
|
||||
|
||||
cudaStream_t stream{nullptr};
|
||||
constexpr auto N = 4096;
|
||||
constexpr auto K = 4096;
|
||||
constexpr auto BLOCK_M = 128;
|
||||
constexpr auto BLOCK_N = 128;
|
||||
constexpr auto kNumStages = 5;
|
||||
constexpr auto kNumTMAMulticast = 2;
|
||||
const int num_sms = 132; // for H100
|
||||
const int best_smem_size = 199376;
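// NOTE (editorial, assumed derivation): 199376 bytes matches the kernel's shared-memory
// layout for this configuration (BLOCK_M = BLOCK_N = BLOCK_K = 128, kNumStages = 5,
// SHAPE_K = 4096, bf16 output, fp8 inputs):
//   D tile:                      128 * 128 * 2             =  32768
//   A + B + A-scales, 5 stages:  5 * (16384 + 16384 + 512) = 166400
//   B scales:                    (4096 / 128) * 4          =    128
//   full + empty barriers:       2 * 5 * 8                 =     80
//                                                    total = 199376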
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmKernel = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
|
||||
int m = options.m;
|
||||
// DeepGEMM requires __nv_fp8_e4m3 input and __nv_bfloat16 output
|
||||
__nv_fp8_e4m3* lhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_lhs.device_data());
|
||||
__nv_fp8_e4m3* rhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_rhs.device_data());
|
||||
float* lhs_scales = Tensor_lhs_scale.device_data();
|
||||
float* rhs_scales = Tensor_rhs_scale.device_data();
|
||||
__nv_bfloat16* out = reinterpret_cast<__nv_bfloat16*>(Tensor_out.device_data());
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmKernel::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmKernel::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmKernel::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmKernel::make_2d_tma_d_desc(out, m);
|
||||
GemmKernel::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
std::cout << "run Gemm...\n";
|
||||
// TODO: reference check
|
||||
Result result;
|
||||
// result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
|
||||
|
||||
// std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
// initialize(options);
|
||||
GemmKernel::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): (128, 128, 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityM: 1 (ScaleMsPerTile: 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityN: 128 (ScaleNsPerTile: 1)" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TestGroupedGemm_Contiguous {
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
using ElementScale = float;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::bfloat16_t;
|
||||
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_lhs;
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_rhs;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::ColumnMajor> Tensor_lhs_scale;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::RowMajor> Tensor_rhs_scale;
|
||||
cutlass::HostTensor<ElementOut, cutlass::layout::RowMajor> Tensor_out;
|
||||
cutlass::HostTensor<int, cutlass::layout::RowMajor> Tensor_grouped_layout;
|
||||
|
||||
/// Initialize operands to be used in the GEMM
|
||||
void initialize(
|
||||
const Options &options,
|
||||
uint64_t seed = 2025) {
|
||||
|
||||
Tensor_lhs.resize({options.m, options.k}); //[m, k]
|
||||
Tensor_rhs.resize({options.num_groups * options.n, options.k}); //[num_groups, n, k]
|
||||
Tensor_lhs_scale.resize({options.m, cdiv(options.k, 128)}); // [m, cdiv(k, 128)] column major
|
||||
Tensor_rhs_scale.resize({options.num_groups * cdiv(options.n, 128), cdiv(options.k, 128)}); // [num_groups, cdiv(n, 128), cdiv(k, 128)]
|
||||
Tensor_out.resize({options.m, options.n}); // [m, n]
|
||||
Tensor_grouped_layout.resize({1, options.m}); // [m], group index of each row of A
|
||||
|
||||
std::vector<int> group_start {0, options.m/4, 2*options.m/4, 3*options.m/4, options.m}; // sum(grouped_layout) = options.m
|
||||
for (int i = 0; i < options.m; ++i) {
|
||||
for(int j = 0; j < options.num_groups; ++j) {
|
||||
if(i >= group_start[j] && i < group_start[j+1]) {
|
||||
Tensor_grouped_layout.host_data()[i] = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
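// NOTE (editorial): grouped_layout[i] holds the group (expert) index of row i of A.
// Rows of the same group are contiguous; the hard-coded group_start above splits M into
// four equal chunks and therefore assumes num_groups == 4, with the per-group row counts
// summing to options.m.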
|
||||
|
||||
initialize_tensor(Tensor_lhs.host_view(), cutlass::Distribution::Uniform, seed + 1);
|
||||
initialize_tensor(Tensor_rhs.host_view(), cutlass::Distribution::Uniform, seed + 2);
|
||||
initialize_scale_tensor(Tensor_lhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 3);
|
||||
initialize_scale_tensor(Tensor_rhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 4);
|
||||
|
||||
Tensor_lhs.sync_device();
|
||||
Tensor_rhs.sync_device();
|
||||
Tensor_lhs_scale.sync_device();
|
||||
Tensor_rhs_scale.sync_device();
|
||||
Tensor_out.sync_device();
|
||||
Tensor_grouped_layout.sync_device();
|
||||
|
||||
}
|
||||
|
||||
void run(Options &options)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
|
||||
|
||||
initialize(options);
|
||||
|
||||
cudaStream_t stream{nullptr};
|
||||
constexpr auto N = 4096;
|
||||
constexpr auto K = 4096;
|
||||
constexpr auto BLOCK_M = 128;
|
||||
constexpr auto BLOCK_N = 128;
|
||||
constexpr auto num_groups = 4;
|
||||
constexpr auto kNumStages = 5;
|
||||
constexpr auto kNumTMAMulticast = 2;
|
||||
const int num_sms = 132; // for H100
|
||||
const int best_smem_size = 199376;
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmKernel = Gemm<N, K, BLOCK_M, BLOCK_N, 128, num_groups, kNumStages, kNumTMAMulticast, GemmType::GroupedContiguous>;
|
||||
|
||||
int m = options.m;
|
||||
// DeepGEMM requires __nv_fp8_e4m3 input and __nv_bfloat16 output
|
||||
__nv_fp8_e4m3* lhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_lhs.device_data());
|
||||
__nv_fp8_e4m3* rhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_rhs.device_data());
|
||||
float* lhs_scales = Tensor_lhs_scale.device_data();
|
||||
float* rhs_scales = Tensor_rhs_scale.device_data();
|
||||
__nv_bfloat16* out = reinterpret_cast<__nv_bfloat16*>(Tensor_out.device_data());
|
||||
int* grouped_layout = Tensor_grouped_layout.device_data();
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmKernel::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmKernel::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmKernel::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmKernel::make_2d_tma_d_desc(out, m);
|
||||
GemmKernel::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
std::cout << "run GroupedGemm Contiguous...\n";
|
||||
// TODO: reference check
|
||||
Result result;
|
||||
// result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
|
||||
|
||||
// std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
// initialize(options);
|
||||
GemmKernel::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Number of groups: " << options.num_groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): (128, 128, 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityM: 1 (ScaleMsPerTile: 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityN: 128 (ScaleNsPerTile: 1)" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TestGroupedGemm_Masked {
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
using ElementScale = float;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::bfloat16_t;
|
||||
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_lhs;
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_rhs;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::ColumnMajor> Tensor_lhs_scale;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::RowMajor> Tensor_rhs_scale;
|
||||
cutlass::HostTensor<ElementOut, cutlass::layout::RowMajor> Tensor_out;
|
||||
cutlass::HostTensor<int, cutlass::layout::RowMajor> Tensor_masked_m;
|
||||
|
||||
/// Initialize operands to be used in the GEMM
|
||||
void initialize(
|
||||
const Options &options,
|
||||
uint64_t seed = 2025) {
|
||||
|
||||
int m_max = options.m;
|
||||
Tensor_lhs.resize({options.num_groups * m_max, options.k}); //[num_groups, m, k]
|
||||
Tensor_rhs.resize({options.num_groups * options.n, options.k}); //[num_groups, n, k]
|
||||
Tensor_lhs_scale.resize({options.num_groups * m_max, cdiv(options.k, 128)}); // [num_groups, m, cdiv(k, 128)] column major
|
||||
Tensor_rhs_scale.resize({options.num_groups * cdiv(options.n, 128), cdiv(options.k, 128)}); // [num_groups, cdiv(n, 128), cdiv(k, 128)]
|
||||
Tensor_out.resize({options.num_groups * m_max, options.n}); // [num_groups, m, n]
|
||||
Tensor_masked_m.resize({1,options.num_groups}); // [num_groups,]
|
||||
|
||||
std::vector<int> masked_m {options.m/4,2*options.m/4,3*options.m/4,options.m}; // max(masked_m) <= options.m
|
||||
for (int i = 0; i < options.num_groups; ++i) {
|
||||
Tensor_masked_m.host_data()[i] = masked_m[i];
|
||||
}
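// NOTE (editorial, assumed semantics): in the masked grouped GEMM every group owns a full
// [m_max, k] slice of A and a [m_max, n] slice of D, but only the first masked_m[g] rows
// of group g are expected to be computed. The hard-coded values above assume
// num_groups == 4 and satisfy masked_m[g] <= options.m.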
|
||||
|
||||
initialize_tensor(Tensor_lhs.host_view(), cutlass::Distribution::Uniform, seed + 1);
|
||||
initialize_tensor(Tensor_rhs.host_view(), cutlass::Distribution::Uniform, seed + 2);
|
||||
initialize_scale_tensor(Tensor_lhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 3);
|
||||
initialize_scale_tensor(Tensor_rhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 4);
|
||||
|
||||
Tensor_lhs.sync_device();
|
||||
Tensor_rhs.sync_device();
|
||||
Tensor_lhs_scale.sync_device();
|
||||
Tensor_rhs_scale.sync_device();
|
||||
Tensor_out.sync_device();
|
||||
Tensor_masked_m.sync_device();
|
||||
|
||||
}
|
||||
|
||||
void run(Options &options)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
|
||||
|
||||
initialize(options);
|
||||
|
||||
cudaStream_t stream{nullptr};
|
||||
constexpr auto N = 4096;
|
||||
constexpr auto K = 4096;
|
||||
constexpr auto BLOCK_M = 128;
|
||||
constexpr auto BLOCK_N = 128;
|
||||
constexpr auto num_groups = 4;
|
||||
constexpr auto kNumStages = 5;
|
||||
constexpr auto kNumTMAMulticast = 2;
|
||||
const int num_sms = 132; // for H100
|
||||
const int best_smem_size = 199376;
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmKernel = Gemm<N, K, BLOCK_M, BLOCK_N, 128, num_groups, kNumStages, kNumTMAMulticast, GemmType::GroupedMasked>;
|
||||
|
||||
int m = options.m;
|
||||
// DeepGEMM requires __nv_fp8_e4m3 input and __nv_bfloat16 output
|
||||
__nv_fp8_e4m3* lhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_lhs.device_data());
|
||||
__nv_fp8_e4m3* rhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_rhs.device_data());
|
||||
float* lhs_scales = Tensor_lhs_scale.device_data();
|
||||
float* rhs_scales = Tensor_rhs_scale.device_data();
|
||||
__nv_bfloat16* out = reinterpret_cast<__nv_bfloat16*>(Tensor_out.device_data());
|
||||
int* masked_m = Tensor_masked_m.device_data();
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmKernel::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmKernel::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmKernel::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmKernel::make_2d_tma_d_desc(out, m);
|
||||
GemmKernel::run(out, rhs_scales, masked_m,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
std::cout << "run GroupedGemm Masked...\n";
|
||||
// TODO: reference check
|
||||
Result result;
|
||||
// result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
|
||||
|
||||
// std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
// initialize(options);
|
||||
GemmKernel::run(out, rhs_scales, masked_m,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: M " << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Number of groups: " << options.num_groups << std::endl;
|
||||
std::cout << " Number of masked rows: " ;
|
||||
for (int i = 0; i < options.num_groups; ++i) {
|
||||
std::cout << Tensor_masked_m.host_data()[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): (128, 128, 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityM: 1 (ScaleMsPerTile: 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityN: 128 (ScaleNsPerTile: 1)" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
<< "(compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
#if defined (CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
TestGemm testgemm{};
|
||||
testgemm.run(options);
|
||||
|
||||
TestGroupedGemm_Contiguous testgroupedgemm_contiguous{};
|
||||
testgroupedgemm_contiguous.run(options);
|
||||
|
||||
TestGroupedGemm_Masked testgroupedgemm_masked{};
|
||||
testgroupedgemm_masked.run(options);
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -35,3 +35,8 @@ cutlass_example_add_executable(
  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
)

cutlass_example_add_executable(
  67_hopper_fp8_deepgemm
  67_hopper_fp8_deepgemm.cu
)

@@ -0,0 +1,13 @@
import torch

from . import jit
from .jit_kernels import (
    gemm_fp8_fp8_bf16_nt,
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
    m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    cell_div,
    set_num_sms, get_num_sms,
    get_col_major_tma_aligned_tensor,
    get_m_alignment_for_contiguous_layout
)
from .utils import bench, bench_kineto, calc_diff

@@ -0,0 +1,444 @@
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
};
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
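// NOTE (editorial): with kNumTMAThreads == 128 and kNumMathThreadsPerGroup == 128
// (the values used by Gemm::run below), this evaluates to 256 threads per SM for
// BLOCK_M == 64 and 384 threads per SM otherwise.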
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too many B scales in a single block");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
|
||||
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_scales_a[kNumStages];
|
||||
float* smem_scales_b;
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
Barrier* empty_barriers[kNumStages];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
|
||||
|
||||
// Fill barriers
|
||||
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
|
||||
DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers");
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicasts");
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [](const auto& func) {
|
||||
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
||||
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(kNumIterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr int kNumTMARegisters = 40;
|
||||
constexpr int kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
|
||||
// Issue TMA A with broadcasting
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
|
||||
|
||||
// Issue TMA B without broadcasting
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
||||
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
|
||||
if constexpr (not kMustUseUniformedScaleB) {
|
||||
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
|
||||
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
|
||||
}
|
||||
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
|
||||
|
||||
// Load B scales with math warp-groups
|
||||
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
||||
if (threadIdx.x >= 32) {
|
||||
auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
|
||||
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
|
||||
}
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
|
||||
}
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
|
||||
// NOTES: even though some blocks do not need to read the second row, we still load one to align with other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
|
||||
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// Write back to shared memory using STSM
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
|
||||
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
|
||||
);
|
||||
}
|
||||
if constexpr (WGMMA::kNumAccum % 8 != 0) {
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
|
||||
);
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
class Gemm {
|
||||
private:
|
||||
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
|
||||
public:
|
||||
Gemm() = default;
|
||||
|
||||
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const CUtensorMap& tma_a_desc,
|
||||
const CUtensorMap& tma_b_desc,
|
||||
const CUtensorMap& tma_scales_a_desc,
|
||||
const CUtensorMap& tma_d_desc,
|
||||
cudaStream_t stream,
|
||||
int num_sms, uint32_t smem_size) {
|
||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
||||
constexpr uint32_t kNumTMAThreads = 128;
|
||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
||||
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
|
||||
kNumTMAMulticast, kGemmType>;
|
||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
||||
|
||||
// Cluster launch
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_sms;
|
||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
|
||||
// Clusters for TMA multicast
|
||||
// NOTES: `>= 4` cluster size will cause performance degradation
|
||||
cudaLaunchAttribute attr;
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
|
||||
// Launch
|
||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
||||
gmem_d, scales_b, grouped_layout,
|
||||
shape_m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = cell_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_desc(
|
||||
T* global_address, Layout layout,
|
||||
uint32_t gmem_rows, uint32_t gmem_cols,
|
||||
uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
@@ -0,0 +1,885 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct SM90_64x16x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %10, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
" %8,"
|
||||
" %9,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 16;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x24x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %14, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11},"
|
||||
" %12,"
|
||||
" %13,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 24;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x32x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %18, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
||||
" %16,"
|
||||
" %17,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 32;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x40x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 40;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x48x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %26, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23},"
                     " %24,"
                     " %25,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 48;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x56x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %30, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27}, "
                     " %28,"
                     " %29,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 56;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x64x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %34, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31}, "
                     " %32,"
                     " %33,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 64;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x72x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %38, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35}, "
                     " %36,"
                     " %37,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 72;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x80x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %42, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39}, "
                     " %40,"
                     " %41,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 80;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x88x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %46, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43}, "
                     " %44,"
                     " %45,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 88;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x96x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %50, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47}, "
                     " %48,"
                     " %49,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 96;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x104x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %54, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51}, "
                     " %52,"
                     " %53,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 104;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x112x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %58, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55}, "
                     " %56,"
                     " %57,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 112;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x120x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 float& d56, float& d57, float& d58, float& d59,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %62, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55, "
                     " %56, %57, %58, %59}, "
                     " %60,"
                     " %61,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              d[56], d[57], d[58], d[59],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 120;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x128x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %66, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55, "
                     " %56, %57, %58, %59, %60, %61, %62, %63}, "
                     " %64,"
                     " %65,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 128;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

struct SM90_64x192x32_F32E4M3E4M3_SS {
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
                                 float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
                                 float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
                                 float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
                                 float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %98, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55, "
                     " %56, %57, %58, %59, %60, %61, %62, %63, "
                     " %64, %65, %66, %67, %68, %69, %70, %71, "
                     " %72, %73, %74, %75, %76, %77, %78, %79, "
                     " %80, %81, %82, %83, %84, %85, %86, %87, "
                     " %88, %89, %90, %91, %92, %93, %94, %95}, "
                     " %96,"
                     " %97,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
                       "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
                       "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
                       "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
                       "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
              d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
              d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
              d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
              d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
              scale_d);
    }

    static constexpr int M = 64;
    static constexpr int N = 192;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

template <typename dtype_t>
struct SM90_U32x2_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
    }
};

template <typename dtype_t>
struct SM90_U32x4_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
        const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
                                 *reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
        asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
    }
};

__device__ void warpgroup_arrive() {
    asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}

__device__ void warpgroup_commit_batch() {
    asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}

__device__ void warpgroup_fence_operand(float& reg) {
    asm volatile("" : "+f"(reg) :: "memory");
}

__forceinline__ __device__ uint32_t get_lane_id() {
    uint32_t lane_id;
    asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
    return lane_id;
}

__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
    uint32_t ret;
    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
    int4 ret;
    asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
    float ret;
    asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ void st_shared(const float* ptr, float val) {
    asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}

__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
    asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}

template <int N>
__device__ void warpgroup_wait() {
    DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
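
// Illustrative usage sketch (editor addition, not part of the original header): a minimal outline
// of how the wrappers and fences above are typically driven from a warpgroup. Descriptor
// preparation and K-loop bookkeeping are left to the caller; all names below are illustrative.
template <typename WGMMA>
__device__ void example_issue_wgmma(uint64_t desc_a, uint64_t desc_b,
                                    float (&accum)[WGMMA::kNumAccum], bool accumulate) {
    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
        warpgroup_fence_operand(accum[i]);
    warpgroup_arrive();                               // wgmma.fence before issuing the group
    WGMMA::wgmma(desc_a, desc_b, accum, accumulate);  // array overload defined above
    warpgroup_commit_batch();                         // close the wgmma group
    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
        warpgroup_fence_operand(accum[i]);
    warpgroup_wait<0>();                              // block until the committed group retires
}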

union GmmaDescriptor {
    __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}

    __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    uint64_t desc_;
    uint32_t reg32_[2];
    uint16_t reg16_[4];

    struct {
        uint16_t start_address_: 14, : 2;
        uint16_t leading_byte_offset_: 14, : 2;
        uint16_t stride_byte_offset_: 14, : 2;
        uint8_t : 1, base_offset_: 3, : 4;
        uint8_t : 6, layout_type_: 2;
    } bitfield;

    // Decay to a `uint64_t`
    __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};

template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
                                         int leading_byte_offset = 0,
                                         int stride_byte_offset = 1024) {
    GmmaDescriptor desc;
    auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    desc.bitfield.start_address_ = uint_ptr >> 4;
    desc.bitfield.layout_type_ = layout_type;
    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.bitfield.base_offset_ = 0;
    return desc;
}
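
// Illustrative sketch (editor addition, an assumption rather than code from this header): wrapping
// a shared-memory tile pointer in a GMMA descriptor. Following CUTLASS's encoding, layout type 1
// is presumed to select 128-byte swizzling; the default leading/stride byte offsets above are kept.
template <typename T>
__device__ GmmaDescriptor example_tile_desc(T* smem_tile) {
    return make_smem_desc(smem_tile, 1);
}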

template <int N>
struct FP8MMASelector {
    static constexpr auto select_type() {
        if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
        if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
        if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
        if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
        if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
        if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
        if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
        if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
        if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
        if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
        if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
        if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
        if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
        if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
        if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
        if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
    }

    using type = decltype(select_type());
};
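
// Usage sketch (editor addition): a kernel is expected to pick the wrapper whose N matches its
// BLOCK_N tile width, e.g. a hypothetical 128-wide tile resolves to the m64n128k32 instruction.
using ExampleMMA = FP8MMASelector<128>::type;  // SM90_64x128x32_F32E4M3E4M3_SS
static_assert(ExampleMMA::N == 128 and ExampleMMA::kNumAccum == 64, "selector picks the matching tile");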

} // namespace deep_gemm
@ -0,0 +1,103 @@
#include "utils.cuh"

namespace deep_gemm {

enum class GemmType {
    Normal,
    GroupedContiguous,
    GroupedMasked
};

#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
          uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t kNumGroups, uint32_t kNumTMAMulticast,
          uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
          uint32_t kNumNBlocksPerGroup = 16>
struct Scheduler {
    int current_iter = -1;
    uint32_t num_aligned_m_blocks;

    // For normal GEMM
    // Maybe not used in the masked grouped GEMM
    uint32_t num_blocks;

    // For grouped GEMM
    int* grouped_layout;
    // Only used for masked layout
    uint32_t curr_group_idx, curr_cumsum;

    __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                                  int* grouped_layout = nullptr) {
        num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
        if constexpr (kGemmType == GemmType::Normal) {
            num_blocks = num_aligned_m_blocks * kNumNBlocks;
        } else if (kGemmType == GemmType::GroupedContiguous) {
            num_blocks = num_aligned_m_blocks * kNumNBlocks;
            this->grouped_layout = grouped_layout;
        } else if (kGemmType == GemmType::GroupedMasked) {
            curr_group_idx = curr_cumsum = 0;
            this->grouped_layout = grouped_layout;
        }
    }

    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
        DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

        // Swizzle for better L2 usage
        auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
        auto group_idx = block_idx / num_blocks_per_group;
        auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
        auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
        auto in_group_idx = block_idx % num_blocks_per_group;
        m_block_idx = in_group_idx / num_n_blocks_in_group;
        n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
    }

    template <bool kIgnoreGroupedForGroupedContiguous=true>
    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
                                                       const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
        if constexpr (kGemmType == GemmType::Normal) {
            return block_idx * block_size;
        } else if (kGemmType == GemmType::GroupedContiguous) {
            auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
            return offset * shape_dim + block_idx * block_size;
        } else if (kGemmType == GemmType::GroupedMasked) {
            return curr_group_idx * shape_dim + block_idx * block_size;
        }
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
        const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;

        if constexpr (kGemmType == GemmType::GroupedMasked) {
            uint32_t num_m_blocks;
            while (true) {
                // End of the task
                if (curr_group_idx == kNumGroups)
                    return false;

                // Within the current group
                num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
                auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
                if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                    break;

                // Move to check the next group
                curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
            }

            get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        } else {
            if (next_block_idx >= num_blocks)
                return false;

            get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        }
        return true;
    }
};
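
// Illustrative sketch (editor addition, not part of the original file): the intended call pattern
// is a persistent loop in which every thread block keeps requesting tiles from the scheduler until
// the work is exhausted. The tile computation itself is elided.
template <GemmType kGemmType, uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t kNumGroups, uint32_t kNumTMAMulticast>
__device__ void example_scheduler_loop(uint32_t shape_m, int* grouped_layout) {
    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
    uint32_t m_block_idx, n_block_idx;
    while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
        // ... load the (m_block_idx, n_block_idx) tile, run the MMA pipeline, store the result
    }
}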
#pragma clang diagnostic pop

} // namespace deep_gemm
@ -0,0 +1,96 @@
#pragma once

#include <cassert>
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/barrier>

#include "utils.cuh"

namespace deep_gemm {

template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
    if constexpr (std::is_same<T, uint8_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, uint16_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT16;
    } else if constexpr (std::is_same<T, uint32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT32;
    } else if constexpr (std::is_same<T, uint64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT64;
    } else if constexpr (std::is_same<T, int32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    } else if constexpr (std::is_same<T, int64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT64;
    } else if constexpr (std::is_same<T, __half>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if constexpr (std::is_same<T, float>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else if constexpr (std::is_same<T, double>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
    }
}

PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
    // Get pointer to `cuTensorMapEncodeTiled`
    cudaDriverEntryPointQueryResult driver_status;
    void* cuTensorMapEncodeTiled_ptr = nullptr;

#if CUDA_VERSION >= 12050
    cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
                                     cudaEnableDefault, &driver_status);
#else
    cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
                            cudaEnableDefault, &driver_status);
#endif

    if (driver_status != cudaDriverEntryPointSuccess)
        throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}

template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
                                  uint64_t stride_in_bytes, uint32_t smem_dim[2],
                                  CUtensorMapSwizzle swizzle_type,
                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
    CUtensorMap tensor_map{};
    constexpr uint32_t rank = 2;
    uint64_t global_stride[rank - 1] = {stride_in_bytes};
    uint32_t elem_strides[rank] = {1, 1};

    if (encode_func == nullptr)
        encode_func = get_cuTensorMapEncodeTiled();

    auto result = encode_func(
        &tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
        global_address, gmem_dim, global_stride, smem_dim, elem_strides,
        CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    DG_HOST_ASSERT(result == CUDA_SUCCESS);
    return tensor_map;
}

template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
         int32_t const& crd_0, int32_t const& crd_1) {
    constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
    if constexpr (kNumTMAMulticast == 1) {
        cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
    } else if (cute::block_rank_in_cluster() == 0) {
        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
    }
}
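
// Illustrative sketch (editor addition; the shapes, swizzle and tile sizes are assumptions, not
// the kernels' actual choices): describing a K-contiguous FP8 matrix of shape (shape_m, shape_k)
// for TMA, with BLOCK_M x BLOCK_K shared-memory tiles and 128-byte swizzling. Note that
// `gmem_dim[0]` is the innermost (contiguous) dimension, per `cuTensorMapEncodeTiled`.
template <uint32_t BLOCK_M, uint32_t BLOCK_K>
CUtensorMap example_make_a_desc(__nv_fp8_e4m3* gmem_a, uint64_t shape_m, uint64_t shape_k) {
    uint64_t gmem_dim[2] = {shape_k, shape_m};
    uint32_t smem_dim[2] = {BLOCK_K, BLOCK_M};
    return make_2d_tma_copy_desc(gmem_a, gmem_dim,
                                 shape_k * sizeof(__nv_fp8_e4m3), smem_dim,
                                 CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B);
}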

} // namespace deep_gemm
@ -0,0 +1,48 @@
#pragma once

#include <exception>
#include <string>

#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif

class AssertionException : public std::exception {
private:
    std::string message{};

public:
    explicit AssertionException(const std::string& message) : message(message) {}

    const char *what() const noexcept override { return message.c_str(); }
};

#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", \
               __FILE__, __LINE__, #cond); \
        throw AssertionException("Assertion failed: " #cond); \
    } \
} while (0)
#endif

#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
        asm("trap;"); \
    } \
} while (0)
#endif

#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif

template <typename T>
__device__ __host__ constexpr T cell_div(T a, T b) {
    return (a + b - 1) / b;
}
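
// Example (editor addition): ceiling division as used for block counts, e.g. 130 rows split into
// 128-row blocks occupy 2 blocks.
static_assert(cell_div(130, 128) == 2, "cell_div rounds the quotient up");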
@ -0,0 +1,3 @@
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .runtime import Runtime
@ -0,0 +1,146 @@
import hashlib
import functools
import os
import re
import subprocess
import uuid
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple

from .runtime import Runtime, RuntimeCache
from .template import typename_map

runtime_cache = RuntimeCache()


def hash_to_hex(s: str) -> str:
    md5 = hashlib.md5()
    md5.update(s.encode('utf-8'))
    return md5.hexdigest()[0:12]


@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
    return f'{os.path.dirname(os.path.abspath(__file__))}/../include'


@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
    # Hash all CUDA headers under the include directory
    include_dir = f'{get_jit_include_dir()}/deep_gemm'
    assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
    md5 = hashlib.md5()
    for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
        with open(f'{include_dir}/{filename}', 'rb') as f:
            md5.update(f.read())

    return md5.hexdigest()[0:12]


@functools.lru_cache(maxsize=None)
def get_nvcc_compiler() -> Tuple[str, str]:
    paths = []
    if os.getenv('DG_NVCC_COMPILER'):
        paths.append(os.getenv('DG_NVCC_COMPILER'))
    paths.append(f'{CUDA_HOME}/bin/nvcc')

    # Try to find the first available NVCC compiler
    least_version_required = '12.3'
    version_pattern = re.compile(r'release (\d+\.\d+)')
    for path in paths:
        if os.path.exists(path):
            match = version_pattern.search(os.popen(f'{path} --version').read())
            assert match, f'Cannot get the version of NVCC compiler {path}'
            version = match.group(1)
            assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
            return path, version
    raise RuntimeError('Cannot find any available NVCC compiler')


@functools.lru_cache(maxsize=None)
def get_default_user_dir():
    if 'DG_CACHE_DIR' in os.environ:
        path = os.getenv('DG_CACHE_DIR')
        os.makedirs(path, exist_ok=True)
        return path
    return os.path.expanduser('~') + '/.deep_gemm'


@functools.lru_cache(maxsize=None)
def get_tmp_dir():
    return f'{get_default_user_dir()}/tmp'


@functools.lru_cache(maxsize=None)
def get_cache_dir():
    return f'{get_default_user_dir()}/cache'


def make_tmp_dir():
    tmp_dir = get_tmp_dir()
    os.makedirs(tmp_dir, exist_ok=True)
    return tmp_dir


def put(path, data, is_binary=False):
    # Write and do POSIX atomic replace
    tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
    with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
        f.write(data)
    os.replace(tmp_file_path, path)


def build(name: str, arg_defs: tuple, code: str) -> Runtime:
    # Compiler flags
    nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
                  '-gencode=arch=compute_90a,code=sm_90a',
                  '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                  # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                  '--diag-suppress=177,174,940']
    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
    flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
    include_dirs = [get_jit_include_dir()]

    # Build signature
    enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
    signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
    name = f'kernel.{name}.{hash_to_hex(signature)}'
    path = f'{get_cache_dir()}/{name}'

    # Check runtime cache or file system hit
    global runtime_cache
    if runtime_cache[path] is not None:
        if os.getenv('DG_JIT_DEBUG', None):
            print(f'Using cached JIT runtime {name} during build')
        return runtime_cache[path]

    # Write the code
    os.makedirs(path, exist_ok=True)
    args_path = f'{path}/kernel.args'
    src_path = f'{path}/kernel.cu'
    put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
    put(src_path, code)

    # Compile into a temporary SO file
    so_path = f'{path}/kernel.so'
    tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'

    # Compile
    command = [get_nvcc_compiler()[0],
               src_path, '-o', tmp_so_path,
               *flags,
               *[f'-I{d}' for d in include_dirs]]
    if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
        print(f'Compiling JIT runtime {name} with command {command}')
    assert subprocess.check_call(command) == 0, f'Failed to compile {src_path}'

    # Interleave FFMA reuse
    if enable_sass_opt:
        pass

    # Atomic replace SO file
    os.replace(tmp_so_path, so_path)

    # Put cache and return
    runtime_cache[path] = Runtime(path)
    return runtime_cache[path]
@ -0,0 +1,66 @@
import ctypes
import os
import torch
from typing import Optional

from .template import map_ctype


class Runtime:
    def __init__(self, path: str) -> None:
        self.path = path
        self.lib = None
        self.args = None

        assert self.is_path_valid(self.path)

    @staticmethod
    def is_path_valid(path: str) -> bool:
        # Exists and is a directory
        if not os.path.exists(path) or not os.path.isdir(path):
            return False

        # Contains all necessary files
        files = ['kernel.cu', 'kernel.args', 'kernel.so']
        return all(os.path.exists(os.path.join(path, file)) for file in files)

    def __call__(self, *args) -> int:
        # Load SO file
        if self.lib is None or self.args is None:
            self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
            with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
                self.args = eval(f.read())

        # Check args and launch
        assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
        cargs = []
        for arg, (name, dtype) in zip(args, self.args):
            if isinstance(arg, torch.Tensor):
                assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
            else:
                assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
            cargs.append(map_ctype(arg))

        return_code = ctypes.c_int(0)
        self.lib.launch(*cargs, ctypes.byref(return_code))
        return return_code.value


class RuntimeCache:
    def __init__(self) -> None:
        self.cache = {}

    def __getitem__(self, path: str) -> Optional[Runtime]:
        # In Python runtime
        if path in self.cache:
            return self.cache[path]

        # Already compiled
        if os.path.exists(path) and Runtime.is_path_valid(path):
            runtime = Runtime(path)
            self.cache[path] = runtime
            return runtime
        return None

    def __setitem__(self, path, runtime) -> None:
        self.cache[path] = runtime
@ -0,0 +1,93 @@
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
|
||||
from typing import Any, Iterable, Dict, Tuple
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
torch.int: 'torch.int',
|
||||
torch.float: 'torch.float',
|
||||
torch.bfloat16: 'torch.bfloat16',
|
||||
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
|
||||
torch.cuda.Stream: 'torch.cuda.Stream',
|
||||
}
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
# Type map for both Python API and source code usages
|
||||
genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
float: ('float', 'float'),
|
||||
torch.int: ('void*', 'int*'),
|
||||
torch.float: ('void*', 'float*'),
|
||||
torch.bfloat16: ('void*', '__nv_bfloat16*'),
|
||||
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
torch.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
|
||||
if isinstance(value, torch.Tensor):
|
||||
return ctype(value.data_ptr())
|
||||
if isinstance(value, torch.cuda.Stream):
|
||||
return ctype(value.cuda_stream)
|
||||
return ctype(value)
|
||||
|
||||
|
||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||
# We don't use `str.format` because it's not safe for C++ {} braces
|
||||
new_template = copy.deepcopy(template)
|
||||
for key, value in keys.items():
|
||||
new_template = new_template.replace(f'{{{key}}}', f'{value}')
|
||||
return new_template
|
||||
|
||||
|
||||
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
|
||||
# Common prefix
|
||||
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
|
||||
|
||||
# Includes
|
||||
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
|
||||
preload_package_includes = ['"cutlass/cutlass.h"']
|
||||
|
||||
assert isinstance(includes, list) or isinstance(includes, tuple)
|
||||
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
|
||||
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
|
||||
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
|
||||
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
|
||||
|
||||
# Function signature
|
||||
raw = '__raw_'
|
||||
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
|
||||
code += f'extern "C" void launch('
|
||||
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
|
||||
code += ') {\n'
|
||||
|
||||
# Cast raw types
|
||||
code += ' // Cast raw types (if needed)\n'
|
||||
for arg_name, arg_type in arg_defs:
|
||||
if genc_map[arg_type][0] != genc_map[arg_type][1]:
|
||||
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
|
||||
|
||||
# Function body
|
||||
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
|
||||
|
||||
# End the function
|
||||
code += '}\n\n'
|
||||
|
||||
# Debug print
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Generated code:\n{code}')
|
||||
|
||||
return code
|
||||
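# Minimal sketch of the two helpers above (all values are illustrative):
# `cpp_format` replaces `{KEY}` placeholders without touching C++ braces, and
# `generate` wraps a body into an `extern "C" void launch(...)` source string.
if __name__ == '__main__':
    _snippet = cpp_format('constexpr auto N = {N}; if (cond) { f(); }', {'N': 128})
    assert _snippet == 'constexpr auto N = 128; if (cond) { f(); }'
    _code = generate(includes=[], arg_defs=(('m', int), ('out', torch.bfloat16)),
                     body='// kernel body goes here')
    assert 'extern "C" void launch(' in _code and '__nv_bfloat16*' in _code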
@ -0,0 +1,10 @@
|
||||
from .gemm import gemm_fp8_fp8_bf16_nt
|
||||
from .m_grouped_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
)
|
||||
from .utils import (
|
||||
cell_div, set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_m_alignment_for_contiguous_layout
|
||||
)
|
||||
@ -0,0 +1,171 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
|
||||
if num_tma_multicast == 1:
|
||||
return True
|
||||
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
|
||||
|
||||
|
||||
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
|
||||
smem_d = block_m * block_n * 2
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b = cell_div(k, block_k) * 4
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
smem_size += smem_d
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2)
|
||||
smem_size += smem_barrier
|
||||
return smem_size
|
||||
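# Worked example for the accounting above (illustrative config): with k = 4096,
# block_m = block_n = 128 and num_stages = 5,
#   smem_d              = 128 * 128 * 2     = 32768 bytes
#   A + B per stage     = (128 + 128) * 128 = 32768 bytes -> 163840 bytes for 5 stages
#   A scales per stage  = 128 * 4           = 512 bytes   -> 2560 bytes for 5 stages
#   B scales            = (4096 / 128) * 4  = 128 bytes (x1, since 128 % 128 == 0)
#   barriers            = 5 * 8 * 2         = 80 bytes
# for a total of 199376 bytes, which fits the 232448-byte SM90 shared memory budget.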
|
||||
|
||||
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
|
||||
if not is_grouped_contiguous:
|
||||
# TODO: for some cases a smaller M block is better; add them to the tuning space
|
||||
block_ms = (64 if m <= 64 else 128, )
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
for block_m in block_ms:
|
||||
for block_n in block_ns:
|
||||
success = False
|
||||
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
success = True
|
||||
elif num_waves == best_num_waves:
|
||||
# Check last wave utilization
|
||||
util = get_last_wave_util(block_m, block_n)
|
||||
best_util = get_last_wave_util(best_block_m, best_block_n)
|
||||
success = util > best_util or (util == best_util and (block_n >= best_block_n and block_m <= best_block_m))
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
||||
assert best_block_m is not None and best_block_n is not None
|
||||
|
||||
# Always pick the largest number of stages that fits into shared memory
|
||||
# NOTES: for double B scales, the best number of stages may be reduced
|
||||
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
||||
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
||||
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
||||
if best_smem_size <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
assert best_num_stages is not None
|
||||
|
||||
# Decide the number of TMA multicast
|
||||
best_num_tma_multicast = 1
|
||||
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
|
||||
best_num_tma_multicast = 2
|
||||
|
||||
return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
|
||||
|
||||
|
||||
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned, transposed layout; if your input does not meet this requirement,
this function will transpose it with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m, n]`, representing the result.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and k > 0
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if the scales are not already in the required layout
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
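# Minimal shape/dtype sketch for the kernel above (illustrative only: it feeds random,
# pre-"quantized" data and assumes an SM90 GPU plus a working JIT toolchain).
if __name__ == '__main__':
    m, n, k = 128, 4096, 7168
    lhs = (torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn),
           torch.rand(m, k // 128, device='cuda', dtype=torch.float32))
    rhs = (torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn),
           torch.rand(n // 128, k // 128, device='cuda', dtype=torch.float32))
    out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
    gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
    torch.cuda.synchronize()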
@ -0,0 +1,182 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .gemm import get_best_configs
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated grouped GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned, transposed layout; if your input does not meet this requirement,
this function will transpose it with a set of slow PyTorch operations.
On the M axis, inputs are grouped into several batches, whose batch sizes must be aligned to
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
||||
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
`m_indices[i]` records the group to which the i-th row of the LHS belongs,
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
Values of `m_indices` within each M-alignment block must be the same.
`-1` in this tensor indicates that no RHS matrix is selected; the kernel will skip the computation for that aligned block.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
num_groups, n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
m__ = m_indices.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ == m__ and k == k_ and n == n_
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert m_indices.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
|
||||
is_grouped_contiguous=True)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
m_indices, m, num_groups,
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
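# Illustrative layout for the contiguous grouped case (all numbers are hypothetical):
# with per-group row counts (256, 128, 384), each a multiple of the 128-row alignment,
# the LHS rows are concatenated along M (m_sum = 768) and the index tensor is
#   m_indices = torch.repeat_interleave(
#       torch.arange(3, device='cuda', dtype=torch.int32),
#       torch.tensor([256, 128, 384], device='cuda'))
# so every aligned 128-row block maps to exactly one RHS matrix; padding blocks may
# be marked with `-1` to be skipped.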
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned, transposed layout; if your input does not meet this requirement,
this function will transpose it with a set of slow PyTorch operations.
Moreover, this alignment requirement differs from that of the contiguous-format kernel, as each batch
must be transposed separately.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`; `masked_m[i]` records the number of rows of `lhs[i]` to actually compute
in the i-th group.
expected_m: a hint (a CPU scalar) for the expected M of each batch;
setting this value correctly may lead to better performance.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
num_groups, m, k = lhs.shape
|
||||
num_groups_, n, k_ = rhs.shape
|
||||
num_groups__, m_, n_ = out.shape
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert masked_m.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
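# Minimal shape/dtype sketch for the masked grouped kernel above (illustrative only:
# random data, and it assumes an SM90 GPU plus a working JIT toolchain).
if __name__ == '__main__':
    num_groups, m_max, n, k = 4, 256, 4096, 7168
    lhs = (torch.randn(num_groups, m_max, k, device='cuda').to(torch.float8_e4m3fn),
           torch.rand(num_groups, m_max, k // 128, device='cuda', dtype=torch.float32))
    rhs = (torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn),
           torch.rand(num_groups, n // 128, k // 128, device='cuda', dtype=torch.float32))
    out = torch.empty(num_groups, m_max, n, device='cuda', dtype=torch.bfloat16)
    masked_m = torch.full((num_groups, ), 192, device='cuda', dtype=torch.int32)
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m, expected_m=192)
    torch.cuda.synchronize()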
@ -0,0 +1,81 @@
|
||||
import copy
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
|
||||
|
||||
class JITTuner:
|
||||
def __init__(self) -> None:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
||||
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
signature = (name, f'{keys}')
|
||||
if signature in self.tuned:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT kernel {name} with keys {keys}')
|
||||
return self.tuned[signature]
|
||||
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
||||
|
||||
assert signature not in self.tuned
|
||||
assert args is not None
|
||||
space = (dict(), ) if len(space) == 0 else space
|
||||
|
||||
kernels = []
|
||||
for tuned_keys in space:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(includes, arg_defs, cpp_format(template, full_keys))
|
||||
|
||||
# Illegal build must raise errors
|
||||
kernels.append((build(name, arg_defs, code), tuned_keys))
|
||||
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
if len(space) > 1:
|
||||
# Check kernel validity
|
||||
return_code = runtime(*args)
|
||||
if return_code != 0:
|
||||
# Skip illegal kernels, e.g. those requiring more shared memory than available
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
continue
|
||||
|
||||
# Measure performance with an L2 flush and a preceding large GEMM kernel to reduce overhead between kernels
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(*args) == 0
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
elapsed_time = start_event.elapsed_time(end_event)
|
||||
else:
|
||||
elapsed_time = 0
|
||||
|
||||
# Compare if better
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
|
||||
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
||||
self.tuned[signature] = best_runtime
|
||||
return best_runtime
|
||||
|
||||
|
||||
jit_tuner = JITTuner()
|
||||
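# Illustrative call with a non-trivial tuning space (names and values are hypothetical;
# `includes`, `arg_defs`, `template` and `args` are assumed to be defined as in the
# kernel modules above): each candidate dict is merged into `keys`, compiled, timed
# with CUDA events after an L2 flush, and the fastest valid kernel is cached.
#   runtime = jit_tuner.compile_and_tune(
#       name='my_kernel', keys={'N': 4096, 'K': 7168},
#       space=({'BLOCK_N': 96}, {'BLOCK_N': 128}),
#       includes=includes, arg_defs=arg_defs, template=template, args=args)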
@ -0,0 +1,105 @@
|
||||
import torch
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
def set_num_sms(num_sms: int) -> None:
|
||||
"""
|
||||
Set the maximum SM count for all GEMM kernels to use.
|
||||
|
||||
Arguments:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
"""
|
||||
Get the current maximum limit of SM count for all GEMM kernels to use.
|
||||
If the count is never specified, the function will return the number of device SMs.
|
||||
|
||||
Returns:
|
||||
Current maximum limit of SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
if _num_sms is None:
|
||||
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
def cell_div(x: int, y: int) -> int:
|
||||
"""
|
||||
Perform ceiling division of two integers.
|
||||
|
||||
Arguments:
|
||||
x: the dividend.
|
||||
y: the divisor.
|
||||
|
||||
Returns:
|
||||
The result of the ceiling division.
|
||||
"""
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
|
||||
"""
|
||||
When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis.
|
||||
Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well
|
||||
with GEMM block shape.
|
||||
|
||||
Returns:
|
||||
Group-level alignment requirement for grouped contiguous layout, which is always 128.
|
||||
"""
|
||||
return 128
|
||||
|
||||
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
"""
|
||||
Global memory address of TMA must be 16-byte aligned.
|
||||
Since we use column-major layout for the LHS scaling tensor,
|
||||
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
|
||||
|
||||
Arguments:
|
||||
x: original M-axis shape of the LHS scaling tensor.
|
||||
element_size: element size of the LHS scaling tensor.
|
||||
|
||||
Returns:
|
||||
M-axis shape of the LHS scaling tensor after padding.
|
||||
"""
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return cell_div(x, alignment) * alignment
|
||||
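# Worked example: FP32 scales have a 4-byte element size, so the alignment is
# 16 / 4 = 4 rows, and an M of 9 is padded to get_tma_aligned_size(9, 4) == 12.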
|
||||
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
# NOTES: for extreme performance, you may rewrite/fuse this function in CUDA
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
if x.dim() == 2:
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b, m, n = x.shape
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
aligned_x[:, :m, :] = x
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
@ -0,0 +1,154 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
torch.cuda.synchronize()
|
||||
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
cache.zero_()
|
||||
|
||||
# Warmup
|
||||
for _ in range(num_warmups):
|
||||
fn()
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
x @ y
|
||||
|
||||
# Testing
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for i in range(num_tests):
|
||||
fn()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
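# Illustrative usage (shapes are arbitrary): the callable is warmed up, the L2 cache
# is flushed once, and the mean per-call time in milliseconds is returned.
#   t = bench(lambda: torch.randn(4096, 4096, device='cuda') @ torch.randn(4096, 4096, device='cuda'))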
|
||||
|
||||
class empty_suppress:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
pass
|
||||
|
||||
|
||||
class suppress_stdout_stderr:
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, 'w')
|
||||
self.errnull_file = open(os.devnull, 'w')
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
|
||||
|
||||
# For some auto-tuning kernels with prints
|
||||
fn()
|
||||
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
with suppress():
|
||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
|
||||
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
||||
if barrier_comm_profiling:
|
||||
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
lhs @ rhs
|
||||
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
|
||||
for _ in range(num_tests):
|
||||
if flush_l2:
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
profiler.step()
|
||||
|
||||
# Return 1 if using Nsight Systems
|
||||
if using_nsys:
|
||||
return 1
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tupled = isinstance(kernel_names, tuple)
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Kernel {name} must appear exactly once in the profiling table'
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
profiler.export_chrome_trace(trace_path)
|
||||
|
||||
# Return average kernel times
|
||||
units = {'ms': 1e3, 'us': 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
||||
break
|
||||
break
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
def calc_diff(x, y):
|
||||
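# Similarity-based relative difference: 1 - 2 * sum(x * y) / sum(x * x + y * y),
# which is 0 for identical tensors and grows with the mismatch.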
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def count_bytes(tensors):
|
||||
total = 0
|
||||
for t in tensors:
|
||||
if isinstance(t, tuple):
|
||||
total += count_bytes(t)
|
||||
else:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
@ -0,0 +1,444 @@
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
};
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
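// One math warp-group (128 threads) when BLOCK_M == 64, two otherwise, plus the TMA warp-group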
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too many B scales in a single block");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
|
||||
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_scales_a[kNumStages];
|
||||
float* smem_scales_b;
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
Barrier* empty_barriers[kNumStages];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
|
||||
|
||||
// Fill barriers
|
||||
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
|
||||
DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers");
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [](const auto& func) {
|
||||
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
||||
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(kNumIterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr int kNumTMARegisters = 40;
|
||||
constexpr int kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
|
||||
// Issue TMA A with broadcasting
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
|
||||
|
||||
// Issue TMA B without broadcasting
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
||||
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
|
||||
if constexpr (not kMustUseUniformedScaleB) {
|
||||
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
|
||||
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
|
||||
}
|
||||
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
|
||||
|
||||
// Load B scales with math warp-groups
|
||||
// NOTES: except for the first warp, we want to overlap loading B scales with TMA stores between tasks
|
||||
if (threadIdx.x >= 32) {
|
||||
auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
|
||||
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
|
||||
}
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
|
||||
}
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
|
||||
// NOTES: even if some blocks do not need to read the second row, we still load one to align with the other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
|
||||
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory reads must happen before `warpgroup_arrive` to avoid the next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// Write back to shared memory using STSM
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
|
||||
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
|
||||
);
|
||||
}
|
||||
if constexpr (WGMMA::kNumAccum % 8 != 0) {
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
|
||||
);
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
class Gemm {
|
||||
private:
|
||||
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
|
||||
public:
|
||||
Gemm() = default;
|
||||
|
||||
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const CUtensorMap& tma_a_desc,
|
||||
const CUtensorMap& tma_b_desc,
|
||||
const CUtensorMap& tma_scales_a_desc,
|
||||
const CUtensorMap& tma_d_desc,
|
||||
cudaStream_t stream,
|
||||
int num_sms, uint32_t smem_size) {
|
||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
||||
constexpr uint32_t kNumTMAThreads = 128;
|
||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
||||
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
|
||||
kNumTMAMulticast, kGemmType>;
|
||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
||||
|
||||
// Cluster launch
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_sms;
|
||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
|
||||
// Clusters for TMA multicast
|
||||
// NOTES: `>= 4` cluster size will cause performance degradation
|
||||
cudaLaunchAttribute attr;
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
|
||||
// Launch
|
||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
||||
gmem_d, scales_b, grouped_layout,
|
||||
shape_m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = cell_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_desc(
|
||||
T* global_address, Layout layout,
|
||||
uint32_t gmem_rows, uint32_t gmem_cols,
|
||||
uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
@ -0,0 +1,885 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct SM90_64x16x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %10, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
" %8,"
|
||||
" %9,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 16;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x24x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %14, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11},"
|
||||
" %12,"
|
||||
" %13,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 24;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x32x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %18, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
||||
" %16,"
|
||||
" %17,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 32;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x40x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 40;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x48x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %26, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
||||
" %24,"
|
||||
" %25,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 48;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x56x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %30, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27}, "
|
||||
" %28,"
|
||||
" %29,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 56;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x64x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %34, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31}, "
|
||||
" %32,"
|
||||
" %33,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 64;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x72x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %38, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35}, "
|
||||
" %36,"
|
||||
" %37,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 72;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x80x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %42, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39}, "
|
||||
" %40,"
|
||||
" %41,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 80;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x88x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %46, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43}, "
|
||||
" %44,"
|
||||
" %45,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 88;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x96x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %50, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47}, "
|
||||
" %48,"
|
||||
" %49,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 96;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x104x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %54, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51}, "
|
||||
" %52,"
|
||||
" %53,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 104;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x112x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %58, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55}, "
|
||||
" %56,"
|
||||
" %57,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 112;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x120x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %62, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59}, "
|
||||
" %60,"
|
||||
" %61,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 120;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x128x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %66, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63}, "
|
||||
" %64,"
|
||||
" %65,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 128;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x192x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
|
||||
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
|
||||
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
|
||||
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %98, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
||||
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
||||
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
||||
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
||||
" %88, %89, %90, %91, %92, %93, %94, %95}, "
|
||||
" %96,"
|
||||
" %97,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
||||
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
||||
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
||||
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
||||
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
|
||||
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
|
||||
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
|
||||
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 192;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
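
// Editor's note (illustrative sketch only; `smem_a`, `smem_b`, `num_k_iters` and the
// descriptors are assumed to come from the surrounding kernel, e.g. via make_smem_desc
// below): the float* overload lets a kernel keep one accumulator array sized by
// kNumAccum and reuse it across the K loop, clearing it on the first iteration by
// passing scale_d = false.
//
//   float accum[SM90_64x128x32_F32E4M3E4M3_SS::kNumAccum];
//   for (int k = 0; k < num_k_iters; ++k) {
//       auto desc_a = make_smem_desc(smem_a + k * 32, 1);
//       auto desc_b = make_smem_desc(smem_b + k * 32, 1);
//       SM90_64x128x32_F32E4M3E4M3_SS::wgmma(desc_a, desc_b, accum, /*scale_d=*/k > 0);
//   }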

template <typename dtype_t>
struct SM90_U32x2_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
    }
};

template <typename dtype_t>
struct SM90_U32x4_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
        const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
                                 *reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
        asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
    }
};
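
// Editor's note (sketch): `stmatrix.x4` stores four 8x8 tiles of 16-bit elements from
// per-lane registers into shared memory. `dtype_t` is expected to be a 32-bit packed
// pair of 16-bit values (e.g. __nv_bfloat162) and `smem_dst` a lane-dependent shared
// memory address, which is why each source operand is reinterpreted as a raw uint32_t.
// A typical call when writing back BF16 results (illustrative; `r0..r3` and
// `smem_d + lane_offset` are assumed to exist in the caller):
//
//   SM90_U32x4_STSM_N<__nv_bfloat162>::copy(r0, r1, r2, r3, smem_d + lane_offset);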

__device__ void warpgroup_arrive() {
    asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}

__device__ void warpgroup_commit_batch() {
    asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}

__device__ void warpgroup_fence_operand(float& reg) {
    asm volatile("" : "+f"(reg) :: "memory");
}

__forceinline__ __device__ uint32_t get_lane_id() {
    uint32_t lane_id;
    asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
    return lane_id;
}

__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
    uint32_t ret;
    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
    int4 ret;
    asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
    float ret;
    asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ void st_shared(const float* ptr, float val) {
    asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}

__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
    asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}

template <int N>
__device__ void warpgroup_wait() {
    DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
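
// Editor's note (sketch of the intended call sequence; `WGMMA`, `accum`, `desc_a`,
// `desc_b` and `scale_d` are assumed to be provided by the surrounding kernel):
// the operand fences bracket the async MMA group so the compiler cannot reorder
// accumulator accesses across it, and warpgroup_wait<0> blocks until the committed
// group has retired.
//
//   for (int i = 0; i < WGMMA::kNumAccum; ++i)
//       warpgroup_fence_operand(accum[i]);
//   warpgroup_arrive();                           // wgmma.fence
//   WGMMA::wgmma(desc_a, desc_b, accum, scale_d);
//   warpgroup_commit_batch();                     // wgmma.commit_group
//   warpgroup_wait<0>();                          // wait for the committed group
//   for (int i = 0; i < WGMMA::kNumAccum; ++i)
//       warpgroup_fence_operand(accum[i]);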
|
||||
|
||||
union GmmaDescriptor {
    __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}

    __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    uint64_t desc_;
    uint32_t reg32_[2];
    uint16_t reg16_[4];

    struct {
        uint16_t start_address_: 14, : 2;
        uint16_t leading_byte_offset_: 14, : 2;
        uint16_t stride_byte_offset_: 14, : 2;
        uint8_t : 1, base_offset_: 3, : 4;
        uint8_t : 6, layout_type_: 2;
    } bitfield;

    // Decay to an `uint64_t`
    __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};

template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
                                         int leading_byte_offset = 0,
                                         int stride_byte_offset = 1024) {
    GmmaDescriptor desc;
    auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    desc.bitfield.start_address_ = uint_ptr >> 4;
    desc.bitfield.layout_type_ = layout_type;
    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.bitfield.base_offset_ = 0;
    return desc;
}
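
// Editor's note (sketch): the descriptor packs a 14-bit shared-memory address together
// with the leading/stride byte offsets, all in 16-byte units (hence the >> 4), plus a
// 2-bit swizzle/layout code; layout code 1 commonly denotes the 128-byte swizzle used
// by these kernels, and the default 1024-byte stride offset matches their K-major FP8
// tiles. Illustrative use for a tile already resident in shared memory (`smem_a` is an
// assumed shared-memory pointer):
//
//   GmmaDescriptor desc_a = make_smem_desc(smem_a, 1);
//   uint64_t raw = desc_a;   // decays to the 64-bit value consumed by wgmma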
|
||||
|
||||
template <int N>
struct FP8MMASelector {
    static constexpr auto select_type() {
        if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
        if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
        if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
        if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
        if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
        if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
        if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
        if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
        if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
        if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
        if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
        if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
        if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
        if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
        if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
        if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
    }

    using type = decltype(select_type());
};

} // namespace deep_gemm
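
// Editor's note (usage sketch): a kernel templated on BLOCK_N picks the matching tile
// at compile time and sizes its accumulator array from it, e.g.
//
//   using WGMMA = typename deep_gemm::FP8MMASelector<BLOCK_N>::type;
//   float accum[WGMMA::kNumAccum];
//
// For an unsupported BLOCK_N, select_type() falls through and deduces void, so any use
// of `type`'s members fails at compile time rather than silently picking a wrong shape.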
|
||||
@ -0,0 +1,103 @@
|
||||
#include "utils.cuh"

namespace deep_gemm {

enum class GemmType {
    Normal,
    GroupedContiguous,
    GroupedMasked
};

#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
          uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t kNumGroups, uint32_t kNumTMAMulticast,
          uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
          uint32_t kNumNBlocksPerGroup = 16>
struct Scheduler {
    int current_iter = -1;
    uint32_t num_aligned_m_blocks;

    // For normal GEMM
    // Maybe not used in the masked grouped GEMM
    uint32_t num_blocks;

    // For grouped GEMM
    int* grouped_layout;
    // Only used for masked layout
    uint32_t curr_group_idx, curr_cumsum;

    __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                                  int* grouped_layout = nullptr) {
        num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
        if constexpr (kGemmType == GemmType::Normal) {
            num_blocks = num_aligned_m_blocks * kNumNBlocks;
        } else if (kGemmType == GemmType::GroupedContiguous) {
            num_blocks = num_aligned_m_blocks * kNumNBlocks;
            this->grouped_layout = grouped_layout;
        } else if (kGemmType == GemmType::GroupedMasked) {
            curr_group_idx = curr_cumsum = 0;
            this->grouped_layout = grouped_layout;
        }
    }

    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
        DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

        // Swizzle for better L2 usages
        auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
        auto group_idx = block_idx / num_blocks_per_group;
        auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
        auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
        auto in_group_idx = block_idx % num_blocks_per_group;
        m_block_idx = in_group_idx / num_n_blocks_in_group;
        n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
    }

    template <bool kIgnoreGroupedForGroupedContiguous=true>
    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
                                                       const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
        if constexpr (kGemmType == GemmType::Normal) {
            return block_idx * block_size;
        } else if (kGemmType == GemmType::GroupedContiguous) {
            auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
            return offset * shape_dim + block_idx * block_size;
        } else if (kGemmType == GemmType::GroupedMasked) {
            return curr_group_idx * shape_dim + block_idx * block_size;
        }
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
        const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;

        if constexpr (kGemmType == GemmType::GroupedMasked) {
            uint32_t num_m_blocks;
            while (true) {
                // End of the task
                if (curr_group_idx == kNumGroups)
                    return false;

                // Within current group
                num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
                auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
                if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                    break;

                // Move to check the next group
                curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
            }

            get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        } else {
            if (next_block_idx >= num_blocks)
                return false;

            get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        }
        return true;
    }
};
#pragma clang diagnostic pop

} // namespace deep_gemm
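
// Editor's note (sketch of how the scheduler is consumed; `shape_m`, `grouped_layout`
// and the template parameters are assumed to come from the launching kernel): each CTA
// repeatedly asks for its next swizzled (m, n) tile until its grid-strided block index
// runs past the total, which makes the kernel persistent across tiles.
//
//   auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N,
//                              kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
//   uint32_t m_block_idx, n_block_idx;
//   while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
//       // load the A/B tiles for (m_block_idx, n_block_idx), run the WGMMA pipeline,
//       // then store the BF16 output tile
//   }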
|
||||
@ -0,0 +1,96 @@
|
||||
#pragma once

#include <cassert>
#include <stdexcept>
#include <type_traits>
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/barrier>

#include "utils.cuh"

namespace deep_gemm {

template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
    if constexpr (std::is_same<T, uint8_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, uint16_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT16;
    } else if constexpr (std::is_same<T, uint32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT32;
    } else if constexpr (std::is_same<T, uint64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT64;
    } else if constexpr (std::is_same<T, int32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    } else if constexpr (std::is_same<T, int64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT64;
    } else if constexpr (std::is_same<T, __half>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if constexpr (std::is_same<T, float>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else if constexpr (std::is_same<T, double>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
    }
}

PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
    // Get pointer to `cuTensorMapEncodeTiled`
    cudaDriverEntryPointQueryResult driver_status;
    void* cuTensorMapEncodeTiled_ptr = nullptr;

#if CUDA_VERSION >= 12050
    cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
                                     cudaEnableDefault, &driver_status);
#else
    cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
                            cudaEnableDefault, &driver_status);
#endif

    if (driver_status != cudaDriverEntryPointSuccess)
        throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}

template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
                                  uint64_t stride_in_bytes, uint32_t smem_dim[2],
                                  CUtensorMapSwizzle swizzle_type,
                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
    CUtensorMap tensor_map{};
    constexpr uint32_t rank = 2;
    uint64_t global_stride[rank - 1] = {stride_in_bytes};
    uint32_t elem_strides[rank] = {1, 1};

    if (encode_func == nullptr)
        encode_func = get_cuTensorMapEncodeTiled();

    auto result = encode_func(
        &tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
        global_address, gmem_dim, global_stride, smem_dim, elem_strides,
        CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    DG_HOST_ASSERT(result == CUDA_SUCCESS);
    return tensor_map;
}

template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
         int32_t const& crd_0, int32_t const& crd_1) {
    constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
    if constexpr (kNumTMAMulticast == 1) {
        cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
    } else if (cute::block_rank_in_cluster() == 0) {
        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
    }
}

} // namespace deep_gemm
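
// Editor's note (host-side sketch; `gmem_ptr`, `shape_m`, `shape_k`, `BLOCK_M` and
// `BLOCK_K` are assumed names, and the {inner, outer} ordering follows the callers of
// make_2d_tma_copy_desc elsewhere in this diff, where dimension 0 is the fastest
// varying one):
//
//   uint64_t gmem_dim[2] = {shape_k, shape_m};   // K-major FP8 matrix
//   uint32_t smem_dim[2] = {BLOCK_K, BLOCK_M};   // box copied per TMA request
//   CUtensorMap tma_a = make_2d_tma_copy_desc(gmem_ptr, gmem_dim,
//                                             shape_k * sizeof(__nv_fp8_e4m3), smem_dim,
//                                             CU_TENSOR_MAP_SWIZZLE_128B);
//
// The resulting map is passed to the kernel and handed to `tma_copy` together with an
// mbarrier that lives in shared memory.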
|
||||
@ -0,0 +1,48 @@
|
||||
#pragma once

#include <exception>
#include <string>

#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif

class AssertionException : public std::exception {
private:
    std::string message{};

public:
    explicit AssertionException(const std::string& message) : message(message) {}

    const char *what() const noexcept override { return message.c_str(); }
};

#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", \
               __FILE__, __LINE__, #cond); \
        throw AssertionException("Assertion failed: " #cond); \
    } \
} while (0)
#endif

#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
        asm("trap;"); \
    } \
} while (0)
#endif

#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif

template <typename T>
__device__ __host__ constexpr T cell_div(T a, T b) {
    return (a + b - 1) / b;
}
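
// Editor's note: `cell_div` is plain ceiling division and is usable in constant
// expressions, e.g. 300 elements need three 128-wide blocks:
static_assert(cell_div(300, 128) == 3, "ceiling division");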
|
||||
@ -0,0 +1,69 @@
|
||||
import os
|
||||
import setuptools
|
||||
import shutil
|
||||
import subprocess
|
||||
from setuptools.command.develop import develop
|
||||
from setuptools.command.install import install
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
jit_include_dirs = ('deep_gemm/include/deep_gemm', )
|
||||
cutlass_dirs = '../../include'
|
||||
third_party_include_dirs = (os.path.join(cutlass_dirs, 'cute'), os.path.join(cutlass_dirs, 'cutlass'))
|
||||
print(third_party_include_dirs)
|
||||
|
||||
|
||||
class PostDevelopCommand(develop):
|
||||
def run(self):
|
||||
develop.run(self)
|
||||
self.make_jit_include_symlinks()
|
||||
|
||||
@staticmethod
|
||||
def make_jit_include_symlinks():
|
||||
# Make symbolic links of third-party include directories
|
||||
for d in third_party_include_dirs:
|
||||
dirname = d.split('/')[-1]
|
||||
src_dir = f'{current_dir}/{d}'
|
||||
dst_dir = f'{current_dir}/deep_gemm/include/{dirname}'
|
||||
if not os.path.exists(src_dir):
|
||||
os.makedirs(src_dir, exist_ok=True)
|
||||
assert os.path.exists(src_dir)
|
||||
if os.path.exists(dst_dir):
|
||||
assert os.path.islink(dst_dir)
|
||||
os.unlink(dst_dir)
|
||||
os.symlink(src_dir, dst_dir, target_is_directory=True)
|
||||
|
||||
|
||||
class PostInstallCommand(install):
|
||||
def run(self):
|
||||
install.run(self)
|
||||
self.copy_jit_includes()
|
||||
|
||||
def copy_jit_includes(self):
|
||||
# Copy include directories needed by JIT
|
||||
shutil.rmtree(f'{self.build_lib}/deep_gemm/include', ignore_errors=True)
|
||||
os.makedirs(f'{self.build_lib}/deep_gemm/include', exist_ok=False)
|
||||
for d in jit_include_dirs + third_party_include_dirs:
|
||||
src_dir = f'{current_dir}/{d}'
|
||||
dst_dir = f'{self.build_lib}/deep_gemm/include/{d.split("/")[-1]}'
|
||||
assert os.path.exists(src_dir)
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cmd = ['git', 'rev-parse', '--short', 'HEAD']
|
||||
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
|
||||
except:
|
||||
revision = ''
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
setuptools.setup(
|
||||
name='deep_gemm',
|
||||
version='1.0.0' + revision,
|
||||
packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'],
|
||||
cmdclass={
|
||||
'develop': PostDevelopCommand,
|
||||
'install': PostInstallCommand
|
||||
}
|
||||
)
|
||||
@ -0,0 +1,158 @@
|
||||
import random
import torch
from typing import Tuple

import deep_gemm
from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor


def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))


def construct(m: int, k: int, n: int) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    ref_out = x @ y.t()

    x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
    # Transpose earlier so that the testing will not trigger transposing kernels
    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
    return x_fp8, y_fp8, out, ref_out
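
# Editor's note: both cast helpers above use the same scheme. For each 128-wide group
# the scale is amax / 448.0 (448 is the largest finite value of float8_e4m3fn), values
# are multiplied by 448 / amax before the FP8 cast, and the GEMM later multiplies the
# FP8 partial products by these per-group scales to recover results in the BF16 range.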
|
||||
|
||||
|
||||
def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||
out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_out = torch.einsum('gmk,gnk->gmn', x, y)
|
||||
|
||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
||||
for i in range(num_groups):
|
||||
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||
|
||||
# For non-masked input, we must merge the group and M dims
|
||||
if not is_masked:
|
||||
x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
|
||||
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for m in (64, 128, 4096):
|
||||
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
|
||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing grouped contiguous GEMM:')
|
||||
|
||||
for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
|
||||
# TODO: make a stronger test
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
||||
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
||||
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
||||
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
||||
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing grouped masked GEMM:')
|
||||
|
||||
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
||||
for k, n in ((7168, 4096), (2048, 7168), ):
|
||||
# Test correctness
|
||||
masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
|
||||
for i in range(10):
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = random.choice(masked_m_candidates)
|
||||
expected_m = int(masked_m.float().mean()) + 1
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
||||
masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
|
||||
|
||||
# Test performance with fixed shapes
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print('Library path:')
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
@ -0,0 +1,64 @@
import os
import torch
from typing import Any

from deep_gemm import jit


class Capture:
    def __init__(self) -> None:
        self.read_fd = None
        self.write_fd = None
        self.saved_stdout = None
        self.captured = None

    def __enter__(self) -> Any:
        self.read_fd, self.write_fd = os.pipe()
        self.saved_stdout = os.dup(1)
        os.dup2(self.write_fd, 1)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        os.dup2(self.saved_stdout, 1)
        os.close(self.write_fd)
        with os.fdopen(self.read_fd, 'r') as f:
            self.captured = f.read()

    def capture(self) -> str:
        return self.captured


if __name__ == '__main__':
    # Runtime
    print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')

    # Templates
    print('Generated code:')
    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
    body = "\n"
    body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
    body += 'std::cout << enable_double_streams << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
    code = jit.generate((), args, body)
    print(code)

    # Build
    print('Building ...')
    func = jit.build('test_func', args, code)

    # Test correctness
    print('Running ...')
    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
    with Capture() as capture:
        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
    output = capture.capture()
    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
    assert output == ref_output, f'{output=}, {ref_output=}'

    print('JIT test passed')
12
examples/68_hopper_flash_mla/README.md
Normal file
@ -0,0 +1,12 @@
# Hopper FlashMLA - Examples
The code in this example is migrated from [FlashMLA](https://github.com/deepseek-ai/FlashMLA/tree/main); it implements an efficient MLA decoding kernel for Hopper GPUs.

# Run the example
### Install
```
python setup.py install
```
### Run the test
```
python tests/test_flash_mla.py
```
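For orientation, below is a minimal, hedged usage sketch of the two functions bound in `csrc/flash_api.cpp` (`get_mla_metadata` and `fwd_kvcache_mla`). The module name `flash_mla_cuda`, the shapes, and the softmax scale are illustrative assumptions and not taken from this example; see `tests/test_flash_mla.py` for the authoritative usage.

```python
import torch
import flash_mla_cuda  # assumed name of the extension built by setup.py

b, s_q, h_q, h_kv = 2, 1, 128, 1            # decoding: one query token per step
d, d_v = 576, 512                           # the kernel requires head_size 576 and head_size_v 512
page_block_size, num_blocks = 64, 1024      # page_block_size must match the kernel's kBlockN (64)

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device='cuda')
kcache = torch.randn(num_blocks, page_block_size, h_kv, d, dtype=torch.bfloat16, device='cuda')
cache_seqlens = torch.full((b,), 4096, dtype=torch.int32, device='cuda')
block_table = torch.arange(b * 512, dtype=torch.int32, device='cuda').view(b, 512)

# Split-KV scheduling metadata; recompute whenever the cached sequence lengths change.
tile_md, num_splits = flash_mla_cuda.get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
    q, kcache, None, d_v, cache_seqlens, block_table,
    d ** -0.5,   # softmax_scale (illustrative value)
    True,        # is_causal
    tile_md, num_splits)
print(out.shape)  # (b, s_q, h_q, d_v)
```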
213
examples/68_hopper_flash_mla/csrc/flash_api.cpp
Normal file
@ -0,0 +1,213 @@
|
||||
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
|
||||
|
||||
#include <torch/python.h>
|
||||
#include <torch/nn/functional.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cutlass/fast_math.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
std::vector<at::Tensor>
|
||||
get_mla_metadata(
|
||||
at::Tensor &seqlens_k,
|
||||
const int num_heads_per_head_k,
|
||||
const int num_heads_k
|
||||
) {
|
||||
// This should match the logic in the MLA kernel.
|
||||
static constexpr int block_size_m = 64;
|
||||
static constexpr int block_size_n = 64;
|
||||
static constexpr int fixed_overhead_num_blocks = 5;
|
||||
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
TORCH_CHECK(seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
|
||||
|
||||
int batch_size = seqlens_k.size(0);
|
||||
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
|
||||
auto options = seqlens_k.options();
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
int sm_count = dprops->multiProcessorCount;
|
||||
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
|
||||
|
||||
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
|
||||
auto num_splits = torch::empty({batch_size + 1}, options);
|
||||
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
int *num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Mla_metadata_params params = {};
|
||||
params.seqlens_k_ptr = seqlens_k_ptr;
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
|
||||
params.num_splits_ptr = num_splits_ptr;
|
||||
params.batch_size = batch_size;
|
||||
params.block_size_n = block_size_n;
|
||||
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
|
||||
params.num_sm_parts = num_sm_parts;
|
||||
get_mla_metadata_func(params, stream);
|
||||
|
||||
return {tile_scheduler_metadata, num_splits};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd_kvcache_mla(
|
||||
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
|
||||
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
||||
const int head_size_v,
|
||||
const at::Tensor &seqlens_k, // batch_size
|
||||
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||
const at::Tensor &num_splits // batch_size + 1
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90);
|
||||
|
||||
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
CHECK_DEVICE(block_table);
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
||||
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q_ori = sizes[1];
|
||||
const int num_heads_ori = sizes[2];
|
||||
const int head_size = sizes[3];
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
|
||||
|
||||
const int max_num_blocks_per_seq = block_table.size(1);
|
||||
const int num_blocks = kcache.size(0);
|
||||
const int page_block_size = kcache.size(1);
|
||||
const int num_heads_k = kcache.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (seqlen_q_ori == 1) { is_causal = false; }
|
||||
|
||||
const int ngroups = num_heads_ori / num_heads_k;
|
||||
const int seqlen_q = seqlen_q_ori * ngroups;
|
||||
const int num_heads = num_heads_k;
|
||||
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q, num_heads, head_size});
|
||||
|
||||
int head_size_k = head_size;
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
|
||||
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
|
||||
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
CHECK_CONTIGUOUS(seqlens_k);
|
||||
CHECK_SHAPE(seqlens_k, batch_size);
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
|
||||
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
Flash_fwd_mla_params params = {};
|
||||
// Set the sizes.
|
||||
params.b = batch_size;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
|
||||
params.h = num_heads;
|
||||
params.h_h_k_ratio = num_heads / num_heads_k;
|
||||
params.ngroups = ngroups;
|
||||
params.is_causal = is_causal;
|
||||
params.d = head_size;
|
||||
params.d_v = head_size_v;
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = kcache.data_ptr();
|
||||
params.v_ptr = vcache.data_ptr();
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
||||
// All strides are in elements, not bytes.
|
||||
params.q_batch_stride = q.stride(0);
|
||||
params.k_batch_stride = kcache.stride(0);
|
||||
params.v_batch_stride = vcache.stride(0);
|
||||
params.o_batch_stride = out.stride(0);
|
||||
params.q_row_stride = q.stride(-3);
|
||||
params.k_row_stride = kcache.stride(-3);
|
||||
params.v_row_stride = vcache.stride(-3);
|
||||
params.o_row_stride = out.stride(-3);
|
||||
params.q_head_stride = q.stride(-2);
|
||||
params.k_head_stride = kcache.stride(-2);
|
||||
params.v_head_stride = vcache.stride(-2);
|
||||
params.o_head_stride = out.stride(-2);
|
||||
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
||||
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
||||
CHECK_DEVICE(tile_scheduler_metadata);
|
||||
CHECK_CONTIGUOUS(tile_scheduler_metadata);
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
params.num_sm_parts = tile_scheduler_metadata.size(0);
|
||||
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
|
||||
CHECK_DEVICE(num_splits);
|
||||
CHECK_CONTIGUOUS(num_splits);
|
||||
params.num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
|
||||
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
||||
params.oaccum_ptr = out_accum.data_ptr();
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(head_size == 576);
|
||||
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
||||
}
|
||||
#ifndef FLASH_MLA_DISABLE_FP16
|
||||
else if (q_dtype == torch::kHalf) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
||||
}
|
||||
|
||||
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
||||
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
|
||||
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
|
||||
|
||||
return {out, softmax_lse};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "FlashMLA";
|
||||
m.def("get_mla_metadata", &get_mla_metadata);
|
||||
m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,3 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
603
examples/68_hopper_flash_mla/csrc/flash_fwd_mla_kernel.h
Normal file
@ -0,0 +1,603 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#include "named_barrier.h"
|
||||
#include "utils.h"
|
||||
#include "softmax.h"
|
||||
#include "static_switch.h"
|
||||
#include "flash_mla.h"
|
||||
|
||||
|
||||
template<typename PrecType, int DIM, int DIM2 = DIM>
|
||||
constexpr auto getSmemLayoutK() {
|
||||
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
|
||||
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
|
||||
|
||||
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
|
||||
return GMMA::Layout_K_SW128_Atom<PrecType>{};
|
||||
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
|
||||
return GMMA::Layout_K_SW64_Atom<PrecType>{};
|
||||
} else {
|
||||
return GMMA::Layout_K_SW32_Atom<PrecType>{};
|
||||
}
|
||||
}
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
struct Flash_fwd_kernel_traits_mla {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using index_t = int64_t;
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
static constexpr int kNWarpsS = 4;
|
||||
static constexpr int kNThreadsS = kNWarpsS * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
|
||||
static_assert(kHeadDimV % 32 == 0);
|
||||
static_assert(kHeadDimV <= kHeadDim);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = decltype(make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
|
||||
GMMA::Major::K, GMMA::Major::K>(),
|
||||
Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
|
||||
|
||||
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
|
||||
using TiledMmaO = decltype(make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
|
||||
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim>(),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutK = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutV = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
||||
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
|
||||
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
||||
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
||||
|
||||
using SmemLayoutAtomO = decltype(composition(
|
||||
Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
|
||||
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopy = decltype(make_tiled_copy(
|
||||
Copy_Atom<Gmem_copy_struct, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using GmemLayoutAtomO = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopyO = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
GmemLayoutAtomO{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
||||
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
|
||||
using GmemLayoutAtomOaccum = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
|
||||
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
|
||||
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
|
||||
};
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct SharedStorageMLA {
|
||||
union {
|
||||
struct {
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
||||
};
|
||||
struct {
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
|
||||
__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx,
|
||||
SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
|
||||
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
||||
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
||||
|
||||
// Epilogue
|
||||
|
||||
const int split_offset = __ldg(params.num_splits_ptr + bidb);
|
||||
|
||||
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
|
||||
|
||||
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
|
||||
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||
// Partition sO to match the accumulator partitioning
|
||||
using SmemTiledCopyO = std::conditional_t<
|
||||
!Split,
|
||||
typename Kernel_traits::SmemCopyAtomO,
|
||||
typename Kernel_traits::SmemCopyAtomOaccum
|
||||
>;
|
||||
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
|
||||
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor rO = flash::convert_type<ElementO>(tOrO);
|
||||
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
__syncthreads();
|
||||
|
||||
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
|
||||
|
||||
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
|
||||
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
|
||||
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
|
||||
GmemTiledCopyO gmem_tiled_copy_Oaccum;
|
||||
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tidx >= kNThreadsS) { return; }
|
||||
|
||||
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
|
||||
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
|
||||
|
||||
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
|
||||
Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
|
||||
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
||||
if (get<1>(taccOcO_row(0)) == 0) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
const int row = get<0>(taccOcO_row(mi));
|
||||
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
|
||||
}
|
||||
}
|
||||
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
|
||||
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms,
|
||||
const int bidb, const int bidh, const int m_block,
|
||||
const int n_split_idx, const int seqlen_k,
|
||||
const int n_block_min, const int n_block_max, const bool NoSplit,
|
||||
SharedStorage &shared_storage) {
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
|
||||
constexpr int kNThreads = Kernel_traits::kNThreads;
|
||||
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
|
||||
static_assert(kNThreads == 256 and kNThreadsS == 128);
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
|
||||
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
|
||||
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
|
||||
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
|
||||
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
|
||||
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
|
||||
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
|
||||
Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
|
||||
|
||||
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
||||
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
||||
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
|
||||
clear(tOrO);
|
||||
|
||||
flash::Softmax<2 * size<1>(tOrO)> softmax;
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 0) {
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
if (n_block % 2 == 1) {
|
||||
// Double buffer for sK
|
||||
constexpr int sK_offset = size(sK);
|
||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
// We need masking on S for the very last block when the K and V length is not a multiple of kBlockN.
|
||||
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||
// We will have at least 1 "masking" iteration.
|
||||
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
|
||||
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
||||
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
|
||||
#pragma unroll 1
|
||||
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
|
||||
__syncthreads();
|
||||
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
|
||||
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
|
||||
|
||||
const bool is_masking_step = masking_step > 0;
|
||||
const bool is_first_masking_step = masking_step == n_masking_steps;
|
||||
|
||||
if (is_masking_step) {
|
||||
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
|
||||
Tensor tScS = thr_mma.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
if constexpr (!Is_causal) { // Just masking based on col
|
||||
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
|
||||
} else {
|
||||
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
|
||||
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
|
||||
int row = int(get<0>(tScS(i)));
|
||||
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
|
||||
if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We have key_padding_mask so we'll need to Check_inf
|
||||
Tensor scale_o = is_first_masking_step
|
||||
? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
||||
: is_masking_step ?
|
||||
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
||||
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = flash::convert_type<Element>(tSrS);
|
||||
cute::copy(rP, tPsP);
|
||||
cute::copy(scale_o, tScale_osScale_o);
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
||||
|
||||
flash::rescale_o(tOrO, scale_o);
|
||||
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
cute::copy(softmax.row_max, tRow_maxsRow_max);
|
||||
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||
} else {
|
||||
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
|
||||
int cur_block_table = __ldg(&block_table[n_block]);
|
||||
|
||||
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.q_row_stride, _1{}));
|
||||
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
|
||||
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
|
||||
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
|
||||
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
params.seqlen_q - m_block * kBlockM);
|
||||
|
||||
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.k_row_stride, _1{}));
|
||||
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
|
||||
auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
|
||||
Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
|
||||
Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
|
||||
Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||
|
||||
if (n_block % 2 == 1) {
|
||||
// Double buffer for sK
|
||||
constexpr int sK_offset = size(sK);
|
||||
tKsK.data() = tKsK.data() + sK_offset;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
// We need to clear the sK smem tiles because K is V.
|
||||
const index_t offset_k = cur_block_table * params.k_batch_stride;
|
||||
tKgK.data() = tKgK.data() + offset_k;
|
||||
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
|
||||
seqlen_k - n_block * kBlockN);
|
||||
tKgK.data() = tKgK.data() + -offset_k;
|
||||
cute::cp_async_fence();
|
||||
|
||||
if (n_block - 1 >= n_block_min) {
|
||||
cur_block_table = __ldg(&block_table[n_block - 1]);
|
||||
}
|
||||
|
||||
#pragma unroll 1
|
||||
for (; n_block >= n_block_min; --n_block) {
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
if (n_block - 1 >= n_block_min) {
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tKsK.data() = tKsK.data() + sK_offset;
|
||||
|
||||
const index_t offset_k = cur_block_table * params.k_batch_stride;
|
||||
tKgK.data() = tKgK.data() + offset_k;
|
||||
flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
|
||||
tKgK.data() = tKgK.data() + -offset_k;
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
||||
|
||||
if (n_block - 2 >= n_block_min) {
|
||||
cur_block_table = __ldg(&block_table[n_block - 2]);
|
||||
}
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
|
||||
Tensor rP = make_tensor<Element>(tSrS_layout);
|
||||
Tensor scale_o = make_tensor<float>(Shape<_2>{});
|
||||
cute::copy(tScale_osScale_o, scale_o);
|
||||
cute::copy(tPsP, rP);
|
||||
|
||||
flash::rescale_o(tOrO, scale_o);
|
||||
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||
cute::copy(tRow_maxsRow_max, softmax.row_max);
|
||||
cute::copy(tRow_sumsRow_sum, softmax.row_sum);
|
||||
}
|
||||
|
||||
if (NoSplit)
|
||||
store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
|
||||
else
|
||||
store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
|
||||
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
|
||||
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
const int m_block = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
|
||||
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
|
||||
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
|
||||
int begin_idx = tile_scheduler_metadata.x;
|
||||
int begin_seqlen = tile_scheduler_metadata.y;
|
||||
int end_idx = tile_scheduler_metadata.z;
|
||||
int end_seqlen = tile_scheduler_metadata.w;
|
||||
if (begin_idx >= params.b) return;
|
||||
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
|
||||
|
||||
#pragma unroll 1
|
||||
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
|
||||
const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
|
||||
const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
|
||||
const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
|
||||
const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
|
||||
const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
|
||||
if (batch_id > begin_idx) {
|
||||
__syncthreads(); // Barrier between two tiles.
|
||||
}
|
||||
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
|
||||
__global__ void __launch_bounds__(256, 1, 1)
|
||||
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
|
||||
constexpr int kNThreads = 128;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int bidx = blockIdx.x;
|
||||
const int hs = params.h * params.seqlen_q;
|
||||
const int batch_idx = bidx / hs;
|
||||
const int hs_idx = bidx % hs;
|
||||
|
||||
const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
|
||||
const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
|
||||
FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
|
||||
if (actual_num_splits == 1) return;
|
||||
|
||||
__shared__ ElementAccum sLseScale[kMaxSplits];
|
||||
|
||||
const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
|
||||
const index_t row_offset_lse = bidx;
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
|
||||
Shape<Int<kMaxSplits>>{}, make_stride(hs));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
float local_lse[kNLsePerThread];
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + tidx;
|
||||
local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
|
||||
}
|
||||
|
||||
float max_lse = -INFINITY;
|
||||
for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
|
||||
for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
|
||||
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
|
||||
|
||||
float sum_lse = 0;
|
||||
for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
|
||||
for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
|
||||
|
||||
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
|
||||
if (tidx == 0) gLSE(0) = global_lse;
|
||||
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + tidx;
|
||||
if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
static_assert(kHeadDimV % kNThreads == 0);
|
||||
constexpr int Elements = kHeadDimV / kNThreads;
|
||||
const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
|
||||
Shape<Int<kHeadDimV>>{}, Stride<_1>{});
|
||||
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
Layout<Shape<Int<kNThreads>>>{},
|
||||
Layout<Shape<Int<Elements>>>{}));
|
||||
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
|
||||
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
|
||||
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
clear(tOrO);
|
||||
|
||||
for (int split = 0; split < actual_num_splits; ++split) {
|
||||
cute::copy(tOgOaccum, tOrOaccum);
|
||||
ElementAccum lse_scale = sLseScale[split];
|
||||
for (int i = 0; i < size(tOrO); ++i) {
|
||||
tOrO(i) += lse_scale * tOrOaccum(i);
|
||||
}
|
||||
tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
|
||||
}
|
||||
|
||||
Tensor rO = flash::convert_type<Element>(tOrO);
|
||||
const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
|
||||
const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
|
||||
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
|
||||
cute::copy(rO, gO);
|
||||
}
|
||||
|
||||
} // namespace flash
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, typename SharedStorage>
|
||||
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
||||
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
|
||||
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
|
||||
constexpr size_t smem_size = sizeof(SharedStorage);
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
});
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
|
||||
dim3 grid_combine(params.b * params.h * params.seqlen_q);
|
||||
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
|
||||
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
|
||||
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
|
||||
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
|
||||
});
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
|
||||
template<typename T, int Headdim>
|
||||
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
||||
static_assert(Headdim == 576);
|
||||
FLASH_ASSERT(params.d_v == 512);
|
||||
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
|
||||
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
|
||||
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
|
||||
}
|
||||
77
examples/68_hopper_flash_mla/csrc/flash_fwd_mla_metadata.cu
Normal file
@ -0,0 +1,77 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
static constexpr int MaxBatchSize = 4096;
|
||||
|
||||
__global__ void __launch_bounds__(256, 1, 1)
|
||||
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
|
||||
int *seqlens_k_ptr = params.seqlens_k_ptr;
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
|
||||
int *num_splits_ptr = params.num_splits_ptr;
|
||||
int batch_size = params.batch_size;
|
||||
int block_size_n = params.block_size_n;
|
||||
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
|
||||
int num_sm_parts = params.num_sm_parts;
|
||||
|
||||
__shared__ int num_blocks_shared[MaxBatchSize];
|
||||
__shared__ int num_splits_shared[MaxBatchSize];
|
||||
|
||||
int total_num_blocks = 0;
|
||||
for (int i = threadIdx.x; i < batch_size; i += 32) {
|
||||
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
|
||||
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
|
||||
num_blocks_shared[i] = num_blocks;
|
||||
}
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
|
||||
|
||||
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
|
||||
num_splits_shared[0] = 0;
|
||||
for (int i = 0; i < num_sm_parts; ++i) {
|
||||
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
|
||||
tile_scheduler_metadata0[0] = now_idx;
|
||||
tile_scheduler_metadata0[1] = now_block * block_size_n;
|
||||
tile_scheduler_metadata1 = now_n_split_idx;
|
||||
int remain_payload = payload;
|
||||
while (now_idx < batch_size) {
|
||||
int num_blocks = num_blocks_shared[now_idx];
|
||||
int now_remain_blocks = num_blocks - now_block;
|
||||
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
|
||||
cum_num_splits += now_n_split_idx + 1;
|
||||
num_splits_shared[now_idx + 1] = cum_num_splits;
|
||||
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
|
||||
++now_idx;
|
||||
now_block = 0;
|
||||
now_n_split_idx = 0;
|
||||
} else {
|
||||
if (remain_payload - fixed_overhead_num_blocks > 0) {
|
||||
now_block += remain_payload - fixed_overhead_num_blocks;
|
||||
++now_n_split_idx;
|
||||
remain_payload = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
|
||||
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
|
||||
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
|
||||
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
|
||||
}
|
||||
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
for (int i = threadIdx.x; i <= batch_size; i += 32) {
|
||||
num_splits_ptr[i] = num_splits_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) {
|
||||
FLASH_ASSERT(params.batch_size < MaxBatchSize);
|
||||
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
63
examples/68_hopper_flash_mla/csrc/flash_mla.h
Normal file
@ -0,0 +1,63 @@
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_mla_params {
|
||||
using index_t = int64_t;
|
||||
|
||||
int b, seqlen_q, d, d_v;
|
||||
int h, h_h_k_ratio, ngroups;
|
||||
bool is_causal;
|
||||
float scale_softmax, scale_softmax_log2;
|
||||
int *__restrict__ cu_seqlens_k;
|
||||
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
void *__restrict__ o_ptr;
|
||||
void *__restrict__ softmax_lse_ptr;
|
||||
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
index_t o_batch_stride;
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t o_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
index_t o_head_stride;
|
||||
|
||||
int *__restrict__ block_table;
|
||||
index_t block_table_batch_stride;
|
||||
int page_block_size;
|
||||
|
||||
int *__restrict__ tile_scheduler_metadata_ptr;
|
||||
int num_sm_parts;
|
||||
int *__restrict__ num_splits_ptr;
|
||||
|
||||
void *__restrict__ softmax_lseaccum_ptr;
|
||||
void *__restrict__ oaccum_ptr;
|
||||
};
|
||||
|
||||
static constexpr int TileSchedulerMetaDataSize = 8;
|
||||
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim>
|
||||
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
|
||||
struct Mla_metadata_params {
|
||||
int *__restrict__ seqlens_k_ptr;
|
||||
int *__restrict__ tile_scheduler_metadata_ptr;
|
||||
int *__restrict__ num_splits_ptr;
|
||||
int batch_size;
|
||||
int block_size_n;
|
||||
int fixed_overhead_num_blocks;
|
||||
int num_sm_parts;
|
||||
};
|
||||
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
|
||||
15
examples/68_hopper_flash_mla/csrc/named_barrier.h
Normal file
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/barrier.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Enumerates the reserved named barriers to avoid potential conflicts
|
||||
|
||||
enum class NamedBarriers {
|
||||
SReady = 1,
|
||||
SoftmaxReady = 2,
|
||||
};
|
||||
|
||||
} // flash
|
||||
197
examples/68_hopper_flash_mla/csrc/softmax.h
Normal file
@ -0,0 +1,197 @@
|
||||
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
// If we don't have float around M_LOG2E the multiplication is done in fp64.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
// The following macro will disable the use of fma.
|
||||
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
|
||||
// This macro is set in PyTorch and not FlashAttention
|
||||
#ifdef UNFUSE_FMA
|
||||
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
|
||||
#else
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
SumOp<float> sum_op;
|
||||
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Tensor0, typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scale_o); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
__forceinline__ __device__ Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scale_o;
|
||||
clear(scale_o);
|
||||
if (Is_first) {
|
||||
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? row_max(mi)
|
||||
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
scale_o(mi) = scores_scale;
|
||||
row_sum(mi) *= scores_scale;
|
||||
}
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
// We don't do the reduce across threads here since we don't need to use the row_sum.
|
||||
// We do that reduce at the end when we need to normalize the softmax.
|
||||
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
|
||||
}
|
||||
return scale_o;
|
||||
};
|
||||
|
||||
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT lse = make_fragment_like(row_sum);
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
return lse;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace flash
|
||||
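As a sanity check of the exp2 rewrite used by scale_apply_exp2 and max_scale_exp2_sum above (computing exp2(x * log2(e) - max * log2(e)) instead of exp(x - max) so the multiply and subtract fuse into an FFMA), here is a small illustrative Python check; it is not part of the example sources:

import math
import torch

# exp(x - m) == 2 ** (x * log2(e) - m * log2(e)); the kernel folds log2(e)
# into the "scale" argument so only one fused multiply-add per element is needed.
log2e = math.log2(math.e)
x = torch.randn(4, 8, dtype=torch.float64)
m = x.max(dim=-1, keepdim=True).values

ref = torch.exp(x - m)
via_exp2 = torch.exp2(x * log2e - m * log2e)
assert torch.allclose(ref, via_exp2, rtol=1e-9, atol=1e-12)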
65  examples/68_hopper_flash_mla/csrc/static_switch.h  Normal file
@@ -0,0 +1,65 @@
#pragma once

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())


#define FLASH_ASSERT(cond)                                                                \
    do {                                                                                  \
        if (not (cond)) {                                                                 \
            fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
            exit(1);                                                                      \
        }                                                                                 \
    } while(0)


#define FLASH_DEVICE_ASSERT(cond)                                                \
    do {                                                                         \
        if (not (cond)) {                                                        \
            printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
            asm("trap;");                                                        \
        }                                                                        \
    } while(0)


#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()


#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
    [&] {                                            \
        if (NUM_SPLITS <= 32) {                      \
            constexpr static int NAME = 32;          \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 64) {               \
            constexpr static int NAME = 64;          \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 96) {               \
            constexpr static int NAME = 96;          \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 128) {              \
            constexpr static int NAME = 128;         \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 160) {              \
            constexpr static int NAME = 160;         \
            return __VA_ARGS__();                    \
        } else {                                     \
            FLASH_ASSERT(false);                     \
        }                                            \
    }()
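MLA_NUM_SPLITS_SWITCH maps a runtime split count onto a small set of compile-time constants so that only a handful of kernel instantiations need to be compiled. A rough Python sketch of the same bucketing rule, for illustration only (the real dispatch happens at C++ compile time):

def round_up_num_splits(num_splits: int) -> int:
    # Mirrors MLA_NUM_SPLITS_SWITCH: pick the smallest supported bucket
    # that is >= num_splits; anything above 160 is rejected.
    for bucket in (32, 64, 96, 128, 160):
        if num_splits <= bucket:
            return bucket
    raise AssertionError(f"num_splits={num_splits} exceeds the largest supported bucket")

assert round_up_num_splits(1) == 32
assert round_up_num_splits(100) == 128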
238  examples/68_hopper_flash_mla/csrc/utils.h  Normal file
@@ -0,0 +1,238 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_bf16.h>

#include <cute/tensor.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        if constexpr (!Transposed) {
            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
        } else {
            return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
        }
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        if constexpr (!Transposed) {
            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
        } else {
            return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});  // ((2, N / 16))
            return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        } else {
            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});  // (((2, 2), N / 32))
            // This combines the first two modes (<0, 0> and <0, 1>) into one mode.
            // Will require register shuffling later to be correct.
            return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
                               get<1>(acc_layout),
                               coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));  // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
            // This combination is right but doesn't work with register shuffling.
            // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
            //                    get<1>(acc_layout),
            //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        }
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
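Allreduce<THREADS> above implements a butterfly reduction with __shfl_xor_sync, halving the XOR offset at every step until all lanes hold the reduced value. A small Python simulation of that exchange pattern, illustrative only:

def xor_shuffle_allreduce(values, op):
    # Simulate Allreduce<len(values)>::run: at each step every lane combines
    # its value with the lane whose index differs by the current XOR offset.
    lanes = len(values)  # must be a power of two, like THREADS in the header
    vals = list(values)
    offset = lanes // 2
    while offset >= 1:
        vals = [op(vals[i], vals[i ^ offset]) for i in range(lanes)]
        offset //= 2
    return vals

# After the reduction, every lane holds the same result.
assert xor_shuffle_allreduce([1.0, 5.0, 2.0, 7.0], max) == [7.0] * 4
assert xor_shuffle_allreduce([1, 2, 3, 4], lambda a, b: a + b) == [10] * 4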
6  examples/68_hopper_flash_mla/flash_mla/__init__.py  Normal file
@@ -0,0 +1,6 @@
__version__ = "1.0.0"

from flash_mla.flash_mla_interface import (
    get_mla_metadata,
    flash_mla_with_kvcache,
)
67  examples/68_hopper_flash_mla/flash_mla/flash_mla_interface.py  Normal file
@@ -0,0 +1,67 @@
from typing import Optional, Tuple

import torch

import flash_mla_cuda


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse
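A minimal end-to-end usage sketch of the interface above; the shapes below are chosen arbitrarily, and it assumes the flash_mla_cuda extension has been built and a Hopper GPU is available:

import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")

# Arbitrary example shapes: MLA uses head_dim 576 for QK and 512 for V.
b, s_q, h_q, h_kv, d, dv, block_size = 2, 1, 16, 1, 576, 512, 64
cache_seqlens = torch.tensor([511, 1024], dtype=torch.int32)
max_blocks_per_seq = 1024 // block_size

q = torch.randn(b, s_q, h_q, d)
k_cache = torch.randn(b * max_blocks_per_seq, block_size, h_kv, d)
block_table = torch.arange(b * max_blocks_per_seq, dtype=torch.int32).view(b, max_blocks_per_seq)

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
out, lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)
print(out.shape, lse.shape)  # (b, s_q, h_q, dv) and (b, h_q, s_q)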
87  examples/68_hopper_flash_mla/setup.py  Normal file
@@ -0,0 +1,87 @@
import os
from pathlib import Path
from datetime import datetime

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
    IS_WINDOWS,
)

DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"


def append_nvcc_threads(nvcc_extra_args):
    nvcc_threads = os.getenv("NVCC_THREADS") or "32"
    return nvcc_extra_args + ["--threads", nvcc_threads]


def get_sources():
    sources = [
        "csrc/flash_api.cpp",
        "csrc/flash_fwd_mla_bf16_sm90.cu",
        "csrc/flash_fwd_mla_metadata.cu",
    ]

    if not DISABLE_FP16:
        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")

    return sources


def get_features_args():
    features_args = []
    if DISABLE_FP16:
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args


cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90a,code=sm_90a")

this_dir = os.path.dirname(os.path.abspath(__file__))

if IS_WINDOWS:
    cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
else:
    cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]

ext_modules = []
ext_modules.append(
    CUDAExtension(
        name="flash_mla_cuda",
        sources=get_sources(),
        extra_compile_args={
            "cxx": cxx_args + get_features_args(),
            "nvcc": append_nvcc_threads(
                [
                    "-O3",
                    "-std=c++17",
                    "-DNDEBUG",
                    "-D_USE_MATH_DEFINES",
                    "-Wno-deprecated-declarations",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    "--ptxas-options=-v,--register-usage-level=10",
                ]
                + cc_flag
            ) + get_features_args(),
        },
        include_dirs=[
            Path(this_dir) / "csrc",
            Path(this_dir) / ".." / ".." / "include",
        ],
    )
)

setup(
    name="flash_mla",
    version="1.0.0",
    packages=find_packages(include=["flash_mla"]),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)
153  examples/68_hopper_flash_mla/tests/test_flash_mla.py  Normal file
@@ -0,0 +1,153 @@
import argparse
import math
import random

import torch
import triton

from flash_mla import flash_mla_with_kvcache, get_mla_metadata


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
    x, y = x.double(), y.double()
    RMSE = ((x - y) * (x - y)).mean().sqrt().item()
    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
    amax_diff = (x - y).abs().max().item()
    # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
    assert cos_diff < 1e-5


@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
    print(
        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
    )

    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(
        b * max_seqlen_pad // block_size, dtype=torch.int32
    ).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
            float("nan")
        )
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )

    def flash_mla():
        return flash_mla_with_kvcache(
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            dv,
            tile_scheduler_metadata,
            num_splits,
            causal=causal,
        )

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q=h_q,
                h_kv=h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    out_flash, lse_flash = flash_mla()
    out_torch, lse_torch = ref_mla()
    cal_diff(out_flash, out_torch, "out")
    cal_diff(lse_flash, lse_torch, "lse")

    t = triton.testing.do_bench(flash_mla)
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
        torch.finfo(q.dtype).bits // 8
    )
    print(
        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
    )


def main(torch_dtype):
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch_dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)

    h_kv = 1
    d, dv = 576, 512
    causal = True

    for b in [128]:
        for s in [4096, 8192]:
            for h_q in [16, 32, 64, 128]:  # TP = 8, 4, 2, 1
                for s_q in [1, 2]:  # MTP = 1, 2
                    for varlen in [False, True]:
                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16"],
        default="bf16",
        help="Data type to use for testing (bf16 or fp16)",
    )

    args = parser.parse_args()

    torch_dtype = torch.bfloat16
    if args.dtype == "fp16":
        torch_dtype = torch.float16

    main(torch_dtype)
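For reference, the throughput figures printed by the benchmark come from the two formulas above: attention FLOPs are s_q * total_seqlens * h_q * (d + dv) * 2, and the byte count covers the KV cache plus Q and O. A worked Python example with one of the tested configurations (numbers are illustrative only):

# Illustrative arithmetic for b=128, s_q=1, seqlen 4096 per sequence, h_q=64,
# h_kv=1, d=576, dv=512, bf16 (2 bytes per element); same formulas as the test.
b, s_q, h_q, h_kv, d, dv = 128, 1, 64, 1, 576, 512
total_seqlens = b * 4096
elem_bytes = 2  # bf16

flops = s_q * total_seqlens * h_q * (d + dv) * 2
bytes_moved = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * elem_bytes

print(f"{flops / 1e12:.3f} TFLOP and {bytes_moved / 1e9:.2f} GB per forward call")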
@@ -146,6 +146,7 @@ foreach(EXAMPLE
  64_ada_fp8_gemm_grouped
  65_distributed_gemm
  67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
  68_hopper_flash_mla
  69_hopper_mixed_dtype_grouped_gemm
  70_blackwell_gemm
  71_blackwell_gemm_with_collective_builder