DeepGemm Support - Step 2 (#2142)
* add cpp example for DeepGemm. * add grouped_gemm_contiguous. * add groupedgemm_masked. * add python example and tests.
This commit is contained in:
@ -68,7 +68,7 @@
|
||||
#include "helper.h"
|
||||
// #include "reference/host/gemm_with_groupwise_scaling.h"
|
||||
|
||||
#include "include/deep_gemm/fp8_gemm.cuh"
|
||||
#include "deep_gemm/include/deep_gemm/fp8_gemm.cuh"
|
||||
|
||||
// using namespace cute;
|
||||
using namespace deep_gemm;
|
||||
|
||||
@ -0,0 +1,13 @@
|
||||
import torch
|
||||
|
||||
from . import jit
|
||||
from .jit_kernels import (
|
||||
gemm_fp8_fp8_bf16_nt,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
cell_div,
|
||||
set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_m_alignment_for_contiguous_layout
|
||||
)
|
||||
from .utils import bench, bench_kineto, calc_diff
|
||||
@ -0,0 +1,444 @@
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
};
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
|
||||
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_scales_a[kNumStages];
|
||||
float* smem_scales_b;
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
Barrier* empty_barriers[kNumStages];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
|
||||
|
||||
// Fill barriers
|
||||
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
|
||||
DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers");
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast");
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [](const auto& func) {
|
||||
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
||||
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(kNumIterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr int kNumTMARegisters = 40;
|
||||
constexpr int kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
|
||||
// Issue TMA A with broadcasting
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
|
||||
|
||||
// Issue TMA B without broadcasting
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
||||
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
|
||||
if constexpr (not kMustUseUniformedScaleB) {
|
||||
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
|
||||
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
|
||||
}
|
||||
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
|
||||
|
||||
// Load B scales with math warp-groups
|
||||
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
||||
if (threadIdx.x >= 32) {
|
||||
auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
|
||||
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
|
||||
}
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
|
||||
}
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
|
||||
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
|
||||
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// Write back to shared memory using STSM
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
|
||||
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
|
||||
);
|
||||
}
|
||||
if constexpr (WGMMA::kNumAccum % 8 != 0) {
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
|
||||
);
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
class Gemm {
|
||||
private:
|
||||
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
|
||||
public:
|
||||
Gemm() = default;
|
||||
|
||||
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const CUtensorMap& tma_a_desc,
|
||||
const CUtensorMap& tma_b_desc,
|
||||
const CUtensorMap& tma_scales_a_desc,
|
||||
const CUtensorMap& tma_d_desc,
|
||||
cudaStream_t stream,
|
||||
int num_sms, uint32_t smem_size) {
|
||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
||||
constexpr uint32_t kNumTMAThreads = 128;
|
||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
||||
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
|
||||
kNumTMAMulticast, kGemmType>;
|
||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
||||
|
||||
// Cluster launch
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_sms;
|
||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
|
||||
// Clusters for TMA multicast
|
||||
// NOTES: `>= 4` cluster size will cause performance degradation
|
||||
cudaLaunchAttribute attr;
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
|
||||
// Launch
|
||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
||||
gmem_d, scales_b, grouped_layout,
|
||||
shape_m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = cell_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_desc(
|
||||
T* global_address, Layout layout,
|
||||
uint32_t gmem_rows, uint32_t gmem_cols,
|
||||
uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
@ -0,0 +1,885 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct SM90_64x16x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %10, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
" %8,"
|
||||
" %9,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 16;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x24x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %14, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11},"
|
||||
" %12,"
|
||||
" %13,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 24;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x32x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %18, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
||||
" %16,"
|
||||
" %17,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 32;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x40x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 40;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x48x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %26, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
||||
" %24,"
|
||||
" %25,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 48;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x56x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %30, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27}, "
|
||||
" %28,"
|
||||
" %29,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 56;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x64x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %34, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31}, "
|
||||
" %32,"
|
||||
" %33,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 64;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x72x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %38, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35}, "
|
||||
" %36,"
|
||||
" %37,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 72;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x80x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %42, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39}, "
|
||||
" %40,"
|
||||
" %41,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 80;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x88x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %46, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43}, "
|
||||
" %44,"
|
||||
" %45,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 88;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x96x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %50, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47}, "
|
||||
" %48,"
|
||||
" %49,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 96;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x104x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %54, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51}, "
|
||||
" %52,"
|
||||
" %53,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 104;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x112x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %58, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55}, "
|
||||
" %56,"
|
||||
" %57,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 112;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x120x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %62, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59}, "
|
||||
" %60,"
|
||||
" %61,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 120;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x128x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %66, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63}, "
|
||||
" %64,"
|
||||
" %65,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 128;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x192x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
|
||||
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
|
||||
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
|
||||
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %98, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
||||
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
||||
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
||||
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
||||
" %88, %89, %90, %91, %92, %93, %94, %95}, "
|
||||
" %96,"
|
||||
" %97,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
||||
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
||||
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
||||
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
||||
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
|
||||
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
|
||||
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
|
||||
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 192;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
|
||||
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
|
||||
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x4_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
|
||||
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
|
||||
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
|
||||
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void warpgroup_arrive() {
|
||||
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_commit_batch() {
|
||||
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_fence_operand(float& reg) {
|
||||
asm volatile("" : "+f"(reg) :: "memory");
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t get_lane_id() {
|
||||
uint32_t lane_id;
|
||||
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
|
||||
int4 ret;
|
||||
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ void warpgroup_wait() {
|
||||
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
union GmmaDescriptor {
|
||||
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint64_t desc_;
|
||||
uint32_t reg32_[2];
|
||||
uint16_t reg16_[4];
|
||||
|
||||
struct {
|
||||
uint16_t start_address_: 14, : 2;
|
||||
uint16_t leading_byte_offset_: 14, : 2;
|
||||
uint16_t stride_byte_offset_: 14, : 2;
|
||||
uint8_t : 1, base_offset_: 3, : 4;
|
||||
uint8_t : 6, layout_type_: 2;
|
||||
} bitfield;
|
||||
|
||||
// Decay to an `uint64_t`
|
||||
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
template <class PointerType>
|
||||
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
|
||||
int leading_byte_offset = 0,
|
||||
int stride_byte_offset = 1024) {
|
||||
GmmaDescriptor desc;
|
||||
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
desc.bitfield.start_address_ = uint_ptr >> 4;
|
||||
desc.bitfield.layout_type_ = layout_type;
|
||||
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
|
||||
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
|
||||
desc.bitfield.base_offset_ = 0;
|
||||
return desc;
|
||||
}
|
||||
|
||||
template <int N>
|
||||
struct FP8MMASelector {
|
||||
static constexpr auto select_type() {
|
||||
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
|
||||
}
|
||||
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,103 @@
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class GemmType {
|
||||
Normal,
|
||||
GroupedContiguous,
|
||||
GroupedMasked
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
||||
template <GemmType kGemmType,
|
||||
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
|
||||
uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
|
||||
uint32_t kNumNBlocksPerGroup = 16>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
uint32_t num_aligned_m_blocks;
|
||||
|
||||
// For normal GEMM
|
||||
// Maybe not used in the masked grouped GEMM
|
||||
uint32_t num_blocks;
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
// Only used for masked layout
|
||||
uint32_t curr_group_idx, curr_cumsum;
|
||||
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
curr_group_idx = curr_cumsum = 0;
|
||||
this->grouped_layout = grouped_layout;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usages
|
||||
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
|
||||
auto group_idx = block_idx / num_blocks_per_group;
|
||||
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
|
||||
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
|
||||
auto in_group_idx = block_idx % num_blocks_per_group;
|
||||
m_block_idx = in_group_idx / num_n_blocks_in_group;
|
||||
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
|
||||
}
|
||||
|
||||
template <bool kIgnoreGroupedForGroupedContiguous=true>
|
||||
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
|
||||
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
return curr_group_idx * shape_dim + block_idx * block_size;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
||||
|
||||
if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
uint32_t num_m_blocks;
|
||||
while (true) {
|
||||
// End of the task
|
||||
if (curr_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
||||
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
|
||||
} else {
|
||||
if (next_block_idx >= num_blocks)
|
||||
return false;
|
||||
|
||||
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,96 @@
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <class T>
|
||||
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
||||
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
} else if constexpr (std::is_same<T, int64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
||||
} else if constexpr (std::is_same<T, __half>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if constexpr (std::is_same<T, double>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
||||
}
|
||||
}
|
||||
|
||||
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
||||
// Get pointer to `cuTensorMapEncodeTiled`
|
||||
cudaDriverEntryPointQueryResult driver_status;
|
||||
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
||||
|
||||
#if CUDA_VERSION >= 12050
|
||||
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#else
|
||||
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#endif
|
||||
|
||||
if (driver_status != cudaDriverEntryPointSuccess)
|
||||
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
||||
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
||||
CUtensorMapSwizzle swizzle_type,
|
||||
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
||||
CUtensorMap tensor_map{};
|
||||
constexpr uint32_t rank = 2;
|
||||
uint64_t global_stride[rank - 1] = {stride_in_bytes};
|
||||
uint32_t elem_strides[rank] = {1, 1};
|
||||
|
||||
if (encode_func == nullptr)
|
||||
encode_func = get_cuTensorMapEncodeTiled();
|
||||
|
||||
auto result = encode_func(
|
||||
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
|
||||
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
template <uint32_t kNumTMAMulticast = 1>
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1) {
|
||||
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
} else if (cute::block_rank_in_cluster() == 0) {
|
||||
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,48 @@
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
class AssertionException : public std::exception {
|
||||
private:
|
||||
std::string message{};
|
||||
|
||||
public:
|
||||
explicit AssertionException(const std::string& message) : message(message) {}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", \
|
||||
__FILE__, __LINE__, #cond); \
|
||||
throw AssertionException("Assertion failed: " #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_STATIC_ASSERT
|
||||
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T cell_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
from .compiler import get_nvcc_compiler, build
|
||||
from .template import cpp_format, generate
|
||||
from .runtime import Runtime
|
||||
@ -0,0 +1,146 @@
|
||||
import hashlib
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
from .template import typename_map
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
|
||||
def hash_to_hex(s: str) -> str:
|
||||
md5 = hashlib.md5()
|
||||
md5.update(s.encode('utf-8'))
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_jit_include_dir() -> str:
|
||||
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
# Update include directories
|
||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||
md5.update(f.read())
|
||||
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_NVCC_COMPILER'))
|
||||
paths.append(f'{CUDA_HOME}/bin/nvcc')
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = '12.3'
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
match = version_pattern.search(os.popen(f'{path} --version').read())
|
||||
version = match.group(1)
|
||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
return path, version
|
||||
raise RuntimeError('Cannot find any available NVCC compiler')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_default_user_dir():
|
||||
if 'DG_CACHE_DIR' in os.environ:
|
||||
path = os.getenv('DG_CACHE_DIR')
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
return os.path.expanduser('~') + '/.deep_gemm'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_tmp_dir():
|
||||
return f'{get_default_user_dir()}/tmp'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_cache_dir():
|
||||
return f'{get_default_user_dir()}/cache'
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
tmp_dir = get_tmp_dir()
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data, is_binary=False):
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=177,174,940']
|
||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
args_path = f'{path}/kernel.args'
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary SO file
|
||||
so_path = f'{path}/kernel.so'
|
||||
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
|
||||
|
||||
# Compile
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', tmp_so_path,
|
||||
*flags,
|
||||
*[f'-I{d}' for d in include_dirs]]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
assert subprocess.check_call(command) == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
pass
|
||||
|
||||
# Atomic replace SO file
|
||||
os.replace(tmp_so_path, so_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
@ -0,0 +1,66 @@
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from .template import map_ctype
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.args = None
|
||||
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
def is_path_valid(path: str) -> bool:
|
||||
# Exists and is a directory
|
||||
if not os.path.exists(path) or not os.path.isdir(path):
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cu', 'kernel.args', 'kernel.so']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, *args) -> int:
|
||||
# Load SO file
|
||||
if self.lib is None or self.args is None:
|
||||
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
|
||||
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
|
||||
self.args = eval(f.read())
|
||||
|
||||
# Check args and launch
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
cargs = []
|
||||
for arg, (name, dtype) in zip(args, self.args):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
|
||||
else:
|
||||
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
|
||||
cargs.append(map_ctype(arg))
|
||||
|
||||
return_code = ctypes.c_int(0)
|
||||
self.lib.launch(*cargs, ctypes.byref(return_code))
|
||||
return return_code.value
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
|
||||
def __getitem__(self, path: str) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
runtime = Runtime(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
@ -0,0 +1,93 @@
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
|
||||
from typing import Any, Iterable, Dict, Tuple
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
torch.int: 'torch.int',
|
||||
torch.float: 'torch.float',
|
||||
torch.bfloat16: 'torch.bfloat16',
|
||||
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
|
||||
torch.cuda.Stream: 'torch.cuda.Stream',
|
||||
}
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
# Type map for both Python API and source code usages
|
||||
genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
float: ('float', 'float'),
|
||||
torch.int: ('void*', 'int*'),
|
||||
torch.float: ('void*', 'float*'),
|
||||
torch.bfloat16: ('void*', '__nv_bfloat16*'),
|
||||
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
torch.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
|
||||
if isinstance(value, torch.Tensor):
|
||||
return ctype(value.data_ptr())
|
||||
if isinstance(value, torch.cuda.Stream):
|
||||
return ctype(value.cuda_stream)
|
||||
return ctype(value)
|
||||
|
||||
|
||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||
# We don't use `str.format` because it's not safe for C++ {} braces
|
||||
new_template = copy.deepcopy(template)
|
||||
for key, value in keys.items():
|
||||
new_template = new_template.replace(f'{{{key}}}', f'{value}')
|
||||
return new_template
|
||||
|
||||
|
||||
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
|
||||
# Common prefix
|
||||
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
|
||||
|
||||
# Includes
|
||||
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
|
||||
preload_package_includes = ['"cutlass/cutlass.h"']
|
||||
|
||||
assert isinstance(includes, list) or isinstance(includes, tuple)
|
||||
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
|
||||
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
|
||||
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
|
||||
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
|
||||
|
||||
# Function signature
|
||||
raw = '__raw_'
|
||||
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
|
||||
code += f'extern "C" void launch('
|
||||
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
|
||||
code += ') {\n'
|
||||
|
||||
# Cast raw types
|
||||
code += ' // Cast raw types (if needed)\n'
|
||||
for arg_name, arg_type in arg_defs:
|
||||
if genc_map[arg_type][0] != genc_map[arg_type][1]:
|
||||
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
|
||||
|
||||
# Function body
|
||||
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
|
||||
|
||||
# End the function
|
||||
code += '}\n\n'
|
||||
|
||||
# Debug print
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Generated code:\n{code}')
|
||||
|
||||
return code
|
||||
@ -0,0 +1,10 @@
|
||||
from .gemm import gemm_fp8_fp8_bf16_nt
|
||||
from .m_grouped_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
)
|
||||
from .utils import (
|
||||
cell_div, set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_m_alignment_for_contiguous_layout
|
||||
)
|
||||
@ -0,0 +1,171 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
|
||||
if num_tma_multicast == 1:
|
||||
return True
|
||||
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
|
||||
|
||||
|
||||
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
|
||||
smem_d = block_m * block_n * 2
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b = cell_div(k, block_k) * 4
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
smem_size += smem_d
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2)
|
||||
smem_size += smem_barrier
|
||||
return smem_size
|
||||
|
||||
|
||||
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
|
||||
if not is_grouped_contiguous:
|
||||
# TODO: for some cases, smaller M block is better, add them into tuning space
|
||||
block_ms = (64 if m <= 64 else 128, )
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
for block_m in block_ms:
|
||||
for block_n in block_ns:
|
||||
success = False
|
||||
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
success = True
|
||||
elif num_waves == best_num_waves:
|
||||
# Check last wave utilization
|
||||
util = get_last_wave_util(block_m, block_n)
|
||||
best_util = get_last_wave_util(best_block_m, best_block_n)
|
||||
success = util > best_util or (util == best_util and (block_n >= best_block_n and block_m <= best_block_m))
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
||||
assert best_block_m is not None and best_block_n is not None
|
||||
|
||||
# Always pick the longest one
|
||||
# NOTES: for double B scales, the best number of stages may be reduced
|
||||
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
||||
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
||||
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
||||
if best_smem_size <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
assert best_num_stages is not None
|
||||
|
||||
# Decide the number of TMA multicast
|
||||
best_num_tma_multicast = 1
|
||||
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
|
||||
best_num_tma_multicast = 2
|
||||
|
||||
return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
|
||||
|
||||
|
||||
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m, n]`, representing the result.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and k > 0
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
@ -0,0 +1,182 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .gemm import get_best_configs
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated grouped GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
||||
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
|
||||
`m_indices[i]` records the group which the j-th row of the LHS belong to,
|
||||
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
`-1` in this tensor indicates no RHS matrix selected, the kernel will skip the computation for that aligned block.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
num_groups, n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
m__ = m_indices.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ == m__ and k == k_ and n == n_
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert m_indices.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
|
||||
is_grouped_contiguous=True)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
m_indices, m, num_groups,
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
in the i-th group.
|
||||
expected_m: a value hint (which is a value on CPU) for the M expectation of each batch,
|
||||
correctly setting this value may lead to better performance.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
num_groups, m, k = lhs.shape
|
||||
num_groups_, n, k_ = rhs.shape
|
||||
num_groups__, m_, n_ = out.shape
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert masked_m.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
@ -0,0 +1,81 @@
|
||||
import copy
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
|
||||
|
||||
class JITTuner:
|
||||
def __init__(self) -> None:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
||||
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
signature = (name, f'{keys}')
|
||||
if signature in self.tuned:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT kernel {name} with keys {keys}')
|
||||
return self.tuned[signature]
|
||||
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
||||
|
||||
assert signature not in self.tuned
|
||||
assert args is not None
|
||||
space = (dict(), ) if len(space) == 0 else space
|
||||
|
||||
kernels = []
|
||||
for tuned_keys in space:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(includes, arg_defs, cpp_format(template, full_keys))
|
||||
|
||||
# Illegal build must raise errors
|
||||
kernels.append((build(name, arg_defs, code), tuned_keys))
|
||||
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
if len(space) > 1:
|
||||
# Check kernel validity
|
||||
return_code = runtime(*args)
|
||||
if return_code != 0:
|
||||
# Pass illegal kernels, e.g. insufficient shared memory capacity
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
continue
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(*args) == 0
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
elapsed_time = start_event.elapsed_time(end_event)
|
||||
else:
|
||||
elapsed_time = 0
|
||||
|
||||
# Compare if better
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
|
||||
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
||||
self.tuned[signature] = best_runtime
|
||||
return best_runtime
|
||||
|
||||
|
||||
jit_tuner = JITTuner()
|
||||
@ -0,0 +1,105 @@
|
||||
import torch
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
def set_num_sms(num_sms: int) -> None:
|
||||
"""
|
||||
Set the maximum SM count for all GEMM kernels to use.
|
||||
|
||||
Arguments:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
"""
|
||||
Get the current maximum limit of SM count for all GEMM kernels to use.
|
||||
If the count is never specified, the function will return the number of device SMs.
|
||||
|
||||
Returns:
|
||||
Current maximum limit of SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
if _num_sms is None:
|
||||
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
def cell_div(x: int, y: int) -> int:
|
||||
"""
|
||||
Perform ceiling division of two integers.
|
||||
|
||||
Args:
|
||||
x: the dividend.
|
||||
y: the divisor.
|
||||
|
||||
Returns:
|
||||
The result of the ceiling division.
|
||||
"""
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
|
||||
"""
|
||||
When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis.
|
||||
Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well
|
||||
with GEMM block shape.
|
||||
|
||||
Returns:
|
||||
Group-level alignment requirement for grouped contiguous layout, which is always 128.
|
||||
"""
|
||||
return 128
|
||||
|
||||
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
"""
|
||||
Global memory address of TMA must be 16-byte aligned.
|
||||
Since we use column-major layout for the LHS scaling tensor,
|
||||
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
|
||||
|
||||
Arguments:
|
||||
x: original M-axis shape of the LHS scaling tensor.
|
||||
element_size: element size of the LHS scaling tensor.
|
||||
|
||||
Returns:
|
||||
M-axis shape of the LHS scaling tensor after padding.
|
||||
"""
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return cell_div(x, alignment) * alignment
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
if x.dim() == 2:
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b, m, n = x.shape
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
aligned_x[:, :m, :] = x
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
@ -0,0 +1,154 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
torch.cuda.synchronize()
|
||||
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
cache.zero_()
|
||||
|
||||
# Warmup
|
||||
for _ in range(num_warmups):
|
||||
fn()
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
x @ y
|
||||
|
||||
# Testing
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for i in range(num_tests):
|
||||
fn()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
|
||||
class empty_suppress:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
pass
|
||||
|
||||
|
||||
class suppress_stdout_stderr:
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, 'w')
|
||||
self.errnull_file = open(os.devnull, 'w')
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
|
||||
|
||||
# For some auto-tuning kernels with prints
|
||||
fn()
|
||||
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
with suppress():
|
||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
|
||||
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
||||
if barrier_comm_profiling:
|
||||
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
lhs @ rhs
|
||||
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
|
||||
for _ in range(num_tests):
|
||||
if flush_l2:
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
profiler.step()
|
||||
|
||||
# Return 1 if using Nsight Systems
|
||||
if using_nsys:
|
||||
return 1
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tupled = isinstance(kernel_names, tuple)
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
profiler.export_chrome_trace(trace_path)
|
||||
|
||||
# Return average kernel times
|
||||
units = {'ms': 1e3, 'us': 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
||||
break
|
||||
break
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
def calc_diff(x, y):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def count_bytes(tensors):
|
||||
total = 0
|
||||
for t in tensors:
|
||||
if isinstance(t, tuple):
|
||||
total += count_bytes(t)
|
||||
else:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
@ -0,0 +1,69 @@
|
||||
import os
|
||||
import setuptools
|
||||
import shutil
|
||||
import subprocess
|
||||
from setuptools.command.develop import develop
|
||||
from setuptools.command.install import install
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
jit_include_dirs = ('deep_gemm/include/deep_gemm', )
|
||||
cutlass_dirs = '../../include'
|
||||
third_party_include_dirs = (os.path.join(cutlass_dirs, 'cute'), os.path.join(cutlass_dirs, 'cutlass'))
|
||||
print(third_party_include_dirs)
|
||||
|
||||
|
||||
class PostDevelopCommand(develop):
|
||||
def run(self):
|
||||
develop.run(self)
|
||||
self.make_jit_include_symlinks()
|
||||
|
||||
@staticmethod
|
||||
def make_jit_include_symlinks():
|
||||
# Make symbolic links of third-party include directories
|
||||
for d in third_party_include_dirs:
|
||||
dirname = d.split('/')[-1]
|
||||
src_dir = f'{current_dir}/{d}'
|
||||
dst_dir = f'{current_dir}/deep_gemm/include/{dirname}'
|
||||
if not os.path.exists(src_dir):
|
||||
os.makedirs(src_dir, exist_ok=True)
|
||||
assert os.path.exists(src_dir)
|
||||
if os.path.exists(dst_dir):
|
||||
assert os.path.islink(dst_dir)
|
||||
os.unlink(dst_dir)
|
||||
os.symlink(src_dir, dst_dir, target_is_directory=True)
|
||||
|
||||
|
||||
class PostInstallCommand(install):
|
||||
def run(self):
|
||||
install.run(self)
|
||||
self.copy_jit_includes()
|
||||
|
||||
def copy_jit_includes(self):
|
||||
# Copy include directories needed by JIT
|
||||
shutil.rmtree(f'{self.build_lib}/deep_gemm/include', ignore_errors=True)
|
||||
os.makedirs(f'{self.build_lib}/deep_gemm/include', exist_ok=False)
|
||||
for d in jit_include_dirs + third_party_include_dirs:
|
||||
src_dir = f'{current_dir}/{d}'
|
||||
dst_dir = f'{self.build_lib}/deep_gemm/include/{d.split("/")[-1]}'
|
||||
assert os.path.exists(src_dir)
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cmd = ['git', 'rev-parse', '--short', 'HEAD']
|
||||
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
|
||||
except:
|
||||
revision = ''
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
setuptools.setup(
|
||||
name='deep_gemm',
|
||||
version='1.0.0' + revision,
|
||||
packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'],
|
||||
cmdclass={
|
||||
'develop': PostDevelopCommand,
|
||||
'install': PostInstallCommand
|
||||
}
|
||||
)
|
||||
@ -0,0 +1,158 @@
|
||||
import random
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
|
||||
|
||||
def construct(m: int, k: int, n: int) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_out = x @ y.t()
|
||||
|
||||
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||
out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_out = torch.einsum('gmk,gnk->gmn', x, y)
|
||||
|
||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
||||
for i in range(num_groups):
|
||||
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||
|
||||
# For non-masked input, we must merge the group and M dims
|
||||
if not is_masked:
|
||||
x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
|
||||
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for m in (64, 128, 4096):
|
||||
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
|
||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing grouped contiguous GEMM:')
|
||||
|
||||
for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
|
||||
# TODO: make a stronger test
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
||||
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
||||
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
||||
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
||||
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing grouped masked GEMM:')
|
||||
|
||||
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
||||
for k, n in ((7168, 4096), (2048, 7168), ):
|
||||
# Test correctness
|
||||
masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
|
||||
for i in range(10):
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = random.choice(masked_m_candidates)
|
||||
expected_m = int(masked_m.float().mean()) + 1
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
||||
masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
|
||||
|
||||
# Test performance with fixed shapes
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print('Library path:')
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
@ -0,0 +1,64 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Any
|
||||
|
||||
from deep_gemm import jit
|
||||
|
||||
|
||||
class Capture:
|
||||
def __init__(self) -> None:
|
||||
self.read_fd = None
|
||||
self.write_fd = None
|
||||
self.saved_stdout = None
|
||||
self.captured = None
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
self.read_fd, self.write_fd = os.pipe()
|
||||
self.saved_stdout = os.dup(1)
|
||||
os.dup2(self.write_fd, 1)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
os.dup2(self.saved_stdout, 1)
|
||||
os.close(self.write_fd)
|
||||
with os.fdopen(self.read_fd, 'r') as f:
|
||||
self.captured = f.read()
|
||||
|
||||
def capture(self) -> str:
|
||||
return self.captured
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Runtime
|
||||
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
|
||||
|
||||
# Templates
|
||||
print('Generated code:')
|
||||
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
|
||||
('enable_double_streams', bool), ('stream', torch.cuda.Stream))
|
||||
body = "\n"
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
|
||||
body += 'std::cout << enable_double_streams << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
|
||||
code = jit.generate((), args, body)
|
||||
print(code)
|
||||
|
||||
# Build
|
||||
print('Building ...')
|
||||
func = jit.build('test_func', args, code)
|
||||
|
||||
# Test correctness
|
||||
print('Running ...')
|
||||
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
|
||||
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
|
||||
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
|
||||
with Capture() as capture:
|
||||
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
|
||||
output = capture.capture()
|
||||
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
|
||||
assert output == ref_output, f'{output=}, {ref_output=}'
|
||||
|
||||
print('JIT test passed')
|
||||
Reference in New Issue
Block a user