diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_deepgemm.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_deepgemm.cu index 1ac49803..7e8e6c40 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_deepgemm.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_deepgemm.cu @@ -68,7 +68,7 @@ #include "helper.h" // #include "reference/host/gemm_with_groupwise_scaling.h" -#include "include/deep_gemm/fp8_gemm.cuh" +#include "deep_gemm/include/deep_gemm/fp8_gemm.cuh" // using namespace cute; using namespace deep_gemm; diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/__init__.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/__init__.py new file mode 100644 index 00000000..27932b0e --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/__init__.py @@ -0,0 +1,13 @@ +import torch + +from . import jit +from .jit_kernels import ( + gemm_fp8_fp8_bf16_nt, + m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, + m_grouped_gemm_fp8_fp8_bf16_nt_masked, + cell_div, + set_num_sms, get_num_sms, + get_col_major_tma_aligned_tensor, + get_m_alignment_for_contiguous_layout +) +from .utils import bench, bench_kineto, calc_diff diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/fp8_gemm.cuh new file mode 100644 index 00000000..bf5249e5 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -0,0 +1,444 @@ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" +#pragma once + +#include +#include + +#include +#include +#include + +#include "mma_utils.cuh" +#include "scheduler.cuh" +#include "tma_utils.cuh" +#include "utils.cuh" + +namespace deep_gemm { + +enum class Layout { + RowMajor, + ColMajor +}; + +template +__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { + DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); + return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; +} + +template +__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) +fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, + uint32_t shape_m, + const __grid_constant__ CUtensorMap tensor_map_a, + const __grid_constant__ CUtensorMap tensor_map_b, + const __grid_constant__ CUtensorMap tensor_map_scales_a, + const __grid_constant__ CUtensorMap tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K); + static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_scales_a[kNumStages]; + float* smem_scales_b; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); + } + smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); + + // Fill barriers + DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers"); + DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers"); + auto barrier_start_ptr = reinterpret_cast(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2)); + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + auto launch_k_iterations = [](const auto& func) { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) + func(k_iter, DivisibleK{}); + } else { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A with broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // Promote with scales + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +template +class Gemm { +private: + using Barrier = cuda::barrier; + +public: + Gemm() = default; + + static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, + uint32_t shape_m, + const CUtensorMap& tma_a_desc, + const CUtensorMap& tma_b_desc, + const CUtensorMap& tma_scales_a_desc, + const CUtensorMap& tma_d_desc, + cudaStream_t stream, + int num_sms, uint32_t smem_size) { + // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + auto kernel = fp8_gemm_kernel; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm(BLOCK_M); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + // NOTES: `>= 4` cluster size will cause performance degradation + cudaLaunchAttribute attr; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {kNumTMAMulticast, 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + + // Launch + auto status = cudaLaunchKernelEx(&config, kernel, + gmem_d, scales_b, grouped_layout, + shape_m, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); + } + + template + static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K); + } + + template + static CUtensorMap make_2d_tma_b_desc(T* global_address) { + return make_2d_tma_desc(global_address, Layout::ColMajor, + SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); + } + + template + static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + shape_m = cell_div(shape_m, kAlignment) * kAlignment; + + return make_2d_tma_desc(global_address, Layout::ColMajor, + shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_desc( + T* global_address, Layout layout, + uint32_t gmem_rows, uint32_t gmem_cols, + uint32_t smem_rows, uint32_t smem_cols, + CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { + if (layout == Layout::RowMajor) { + uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; + uint32_t smem_dim[2] = {smem_cols, smem_rows}; + return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); + } else { + uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; + uint32_t smem_dim[2] = {smem_rows, smem_cols}; + return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); + } + } +}; + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/mma_utils.cuh b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/mma_utils.cuh new file mode 100644 index 00000000..b44bf956 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -0,0 +1,885 @@ +#pragma once + +#include + +#include "utils.cuh" + +namespace deep_gemm { + +struct SM90_64x16x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 16; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x24x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 24; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x32x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 32; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x40x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 40; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x48x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 48; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x56x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}, " + " %28," + " %29," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 56; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x64x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + " %32," + " %33," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 64; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x72x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}, " + " %36," + " %37," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 72; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x80x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}, " + " %40," + " %41," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 80; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x88x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}, " + " %44," + " %45," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 88; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x96x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}, " + " %48," + " %49," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 96; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x104x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}, " + " %52," + " %53," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 104; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x112x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}, " + " %56," + " %57," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 112; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x120x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}, " + " %60," + " %61," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 120; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x128x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + " %64," + " %65," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 128; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x192x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, + float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87, + float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}, " + " %96," + " %97," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], + d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], + d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 192; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { + const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + +__device__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__device__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__device__ void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +__forceinline__ __device__ uint32_t get_lane_id() { + uint32_t lane_id; + asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) { + int4 ret; + asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); +} + +__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); +} + +template +__device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +union GmmaDescriptor { + __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} + + __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + struct { + uint16_t start_address_: 14, : 2; + uint16_t leading_byte_offset_: 14, : 2; + uint16_t stride_byte_offset_: 14, : 2; + uint8_t : 1, base_offset_: 3, : 4; + uint8_t : 6, layout_type_: 2; + } bitfield; + + // Decay to an `uint64_t` + __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } +}; + +template +__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, + int leading_byte_offset = 0, + int stride_byte_offset = 1024) { + GmmaDescriptor desc; + auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +struct FP8MMASelector { + static constexpr auto select_type() { + if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS(); + if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS(); + if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS(); + if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS(); + if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS(); + if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS(); + if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS(); + if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS(); + if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS(); + if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS(); + if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS(); + if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS(); + if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS(); + if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); + if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS(); + if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); + } + + using type = decltype(select_type()); +}; + +} // namespace deep_gemm diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/scheduler.cuh b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/scheduler.cuh new file mode 100644 index 00000000..5e1c211b --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/scheduler.cuh @@ -0,0 +1,103 @@ +#include "utils.cuh" + +namespace deep_gemm { + +enum class GemmType { + Normal, + GroupedContiguous, + GroupedMasked +}; + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template +struct Scheduler { + int current_iter = -1; + uint32_t num_aligned_m_blocks; + + // For normal GEMM + // Maybe not used in the masked grouped GEMM + uint32_t num_blocks; + + // For grouped GEMM + int* grouped_layout; + // Only used for masked layout + uint32_t curr_group_idx, curr_cumsum; + + __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m, + int* grouped_layout = nullptr) { + num_aligned_m_blocks = cell_div(shape_m, BLOCK_M); + if constexpr (kGemmType == GemmType::Normal) { + num_blocks = num_aligned_m_blocks * kNumNBlocks; + } else if (kGemmType == GemmType::GroupedContiguous) { + num_blocks = num_aligned_m_blocks * kNumNBlocks; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::GroupedMasked) { + curr_group_idx = curr_cumsum = 0; + this->grouped_layout = grouped_layout; + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; + auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = in_group_idx / num_n_blocks_in_group; + n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx=0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if (kGemmType == GemmType::GroupedContiguous) { + auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); + return offset * shape_dim + block_idx * block_size; + } else if (kGemmType == GemmType::GroupedMasked) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; + + if constexpr (kGemmType == GemmType::GroupedMasked) { + uint32_t num_m_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = cell_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumNBlocks) + break; + + // Move to check the next group + curr_group_idx ++, curr_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); + } else { + if (next_block_idx >= num_blocks) + return false; + + get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); + } + return true; + } +}; +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/tma_utils.cuh b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/tma_utils.cuh new file mode 100644 index 00000000..c938c4d9 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/tma_utils.cuh @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "utils.cuh" + +namespace deep_gemm { + +template +constexpr CUtensorMapDataType get_CUtensorMapDataType() { + if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT64; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_INT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_INT64; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; + } +} + +PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { + // Get pointer to `cuTensorMapEncodeTiled` + cudaDriverEntryPointQueryResult driver_status; + void* cuTensorMapEncodeTiled_ptr = nullptr; + +#if CUDA_VERSION >= 12050 + cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, + cudaEnableDefault, &driver_status); +#else + cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, + cudaEnableDefault, &driver_status); +#endif + + if (driver_status != cudaDriverEntryPointSuccess) + throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess"); + return reinterpret_cast(cuTensorMapEncodeTiled_ptr); +} + +template +CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], + uint64_t stride_in_bytes, uint32_t smem_dim[2], + CUtensorMapSwizzle swizzle_type, + PFN_cuTensorMapEncodeTiled encode_func = nullptr) { + CUtensorMap tensor_map{}; + constexpr uint32_t rank = 2; + uint64_t global_stride[rank - 1] = {stride_in_bytes}; + uint32_t elem_strides[rank] = {1, 1}; + + if (encode_func == nullptr) + encode_func = get_cuTensorMapEncodeTiled(); + + auto result = encode_func( + &tensor_map, get_CUtensorMapDataType::type>(), rank, + global_address, gmem_dim, global_stride, smem_dim, elem_strides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + DG_HOST_ASSERT(result == CUDA_SUCCESS); + return tensor_map; +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, + int32_t const& crd_0, int32_t const& crd_1) { + constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); + if constexpr (kNumTMAMulticast == 1) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); + } else if (cute::block_rank_in_cluster() == 0) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1); + } +} + +} // namespace deep_gemm diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/utils.cuh b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/utils.cuh new file mode 100644 index 00000000..608945d7 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/include/deep_gemm/utils.cuh @@ -0,0 +1,48 @@ +#pragma once + +#include + +#ifdef __CLION_IDE__ +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); } +#define printf host_device_printf +#endif + +class AssertionException : public std::exception { +private: + std::string message{}; + +public: + explicit AssertionException(const std::string& message) : message(message) {} + + const char *what() const noexcept override { return message.c_str(); } +}; + +#ifndef DG_HOST_ASSERT +#define DG_HOST_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", \ + __FILE__, __LINE__, #cond); \ + throw AssertionException("Assertion failed: " #cond); \ + } \ +} while (0) +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) +#endif + +template +__device__ __host__ constexpr T cell_div(T a, T b) { + return (a + b - 1) / b; +} diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/__init__.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/__init__.py new file mode 100644 index 00000000..eb08b142 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/__init__.py @@ -0,0 +1,3 @@ +from .compiler import get_nvcc_compiler, build +from .template import cpp_format, generate +from .runtime import Runtime diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/compiler.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/compiler.py new file mode 100644 index 00000000..0383eef5 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/compiler.py @@ -0,0 +1,146 @@ +import hashlib +import functools +import os +import re +import subprocess +import uuid +from torch.utils.cpp_extension import CUDA_HOME +from typing import Tuple + +from .runtime import Runtime, RuntimeCache +from .template import typename_map + +runtime_cache = RuntimeCache() + + +def hash_to_hex(s: str) -> str: + md5 = hashlib.md5() + md5.update(s.encode('utf-8')) + return md5.hexdigest()[0:12] + + +@functools.lru_cache(maxsize=None) +def get_jit_include_dir() -> str: + return f'{os.path.dirname(os.path.abspath(__file__))}/../include' + + +@functools.lru_cache(maxsize=None) +def get_deep_gemm_version() -> str: + # Update include directories + include_dir = f'{get_jit_include_dir()}/deep_gemm' + assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' + md5 = hashlib.md5() + for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): + with open(f'{include_dir}/{filename}', 'rb') as f: + md5.update(f.read()) + + return md5.hexdigest()[0:12] + + +@functools.lru_cache(maxsize=None) +def get_nvcc_compiler() -> Tuple[str, str]: + paths = [] + if os.getenv('DG_NVCC_COMPILER'): + paths.append(os.getenv('DG_NVCC_COMPILER')) + paths.append(f'{CUDA_HOME}/bin/nvcc') + + # Try to find the first available NVCC compiler + least_version_required = '12.3' + version_pattern = re.compile(r'release (\d+\.\d+)') + for path in paths: + if os.path.exists(path): + match = version_pattern.search(os.popen(f'{path} --version').read()) + version = match.group(1) + assert match, f'Cannot get the version of NVCC compiler {path}' + assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' + return path, version + raise RuntimeError('Cannot find any available NVCC compiler') + + +@functools.lru_cache(maxsize=None) +def get_default_user_dir(): + if 'DG_CACHE_DIR' in os.environ: + path = os.getenv('DG_CACHE_DIR') + os.makedirs(path, exist_ok=True) + return path + return os.path.expanduser('~') + '/.deep_gemm' + + +@functools.lru_cache(maxsize=None) +def get_tmp_dir(): + return f'{get_default_user_dir()}/tmp' + + +@functools.lru_cache(maxsize=None) +def get_cache_dir(): + return f'{get_default_user_dir()}/cache' + + +def make_tmp_dir(): + tmp_dir = get_tmp_dir() + os.makedirs(tmp_dir, exist_ok=True) + return tmp_dir + + +def put(path, data, is_binary=False): + # Write and do POSIX atomic replace + tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}' + with open(tmp_file_path, 'wb' if is_binary else 'w') as f: + f.write(data) + os.replace(tmp_file_path, path) + + +def build(name: str, arg_defs: tuple, code: str) -> Runtime: + # Compiler flags + nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', + '-gencode=arch=compute_90a,code=sm_90a', + '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), + # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases + '--diag-suppress=177,174,940'] + cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi'] + flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] + include_dirs = [get_jit_include_dir()] + + # Build signature + enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 + signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' + name = f'kernel.{name}.{hash_to_hex(signature)}' + path = f'{get_cache_dir()}/{name}' + + # Check runtime cache or file system hit + global runtime_cache + if runtime_cache[path] is not None: + if os.getenv('DG_JIT_DEBUG', None): + print(f'Using cached JIT runtime {name} during build') + return runtime_cache[path] + + # Write the code + os.makedirs(path, exist_ok=True) + args_path = f'{path}/kernel.args' + src_path = f'{path}/kernel.cu' + put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs])) + put(src_path, code) + + # Compile into a temporary SO file + so_path = f'{path}/kernel.so' + tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so' + + # Compile + command = [get_nvcc_compiler()[0], + src_path, '-o', tmp_so_path, + *flags, + *[f'-I{d}' for d in include_dirs]] + if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): + print(f'Compiling JIT runtime {name} with command {command}') + assert subprocess.check_call(command) == 0, f'Failed to compile {src_path}' + + # Interleave FFMA reuse + if enable_sass_opt: + pass + + # Atomic replace SO file + os.replace(tmp_so_path, so_path) + + # Put cache and return + runtime_cache[path] = Runtime(path) + return runtime_cache[path] diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/runtime.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/runtime.py new file mode 100644 index 00000000..66c370a6 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/runtime.py @@ -0,0 +1,66 @@ +import ctypes +import os +import torch +from typing import Optional + +from .template import map_ctype + + +class Runtime: + def __init__(self, path: str) -> None: + self.path = path + self.lib = None + self.args = None + + assert self.is_path_valid(self.path) + + @staticmethod + def is_path_valid(path: str) -> bool: + # Exists and is a directory + if not os.path.exists(path) or not os.path.isdir(path): + return False + + # Contains all necessary files + files = ['kernel.cu', 'kernel.args', 'kernel.so'] + return all(os.path.exists(os.path.join(path, file)) for file in files) + + def __call__(self, *args) -> int: + # Load SO file + if self.lib is None or self.args is None: + self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so')) + with open(os.path.join(self.path, 'kernel.args'), 'r') as f: + self.args = eval(f.read()) + + # Check args and launch + assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}' + cargs = [] + for arg, (name, dtype) in zip(args, self.args): + if isinstance(arg, torch.Tensor): + assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`' + else: + assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`' + cargs.append(map_ctype(arg)) + + return_code = ctypes.c_int(0) + self.lib.launch(*cargs, ctypes.byref(return_code)) + return return_code.value + + +class RuntimeCache: + def __init__(self) -> None: + self.cache = {} + + def __getitem__(self, path: str) -> Optional[Runtime]: + # In Python runtime + if path in self.cache: + return self.cache[path] + + # Already compiled + if os.path.exists(path) and Runtime.is_path_valid(path): + runtime = Runtime(path) + self.cache[path] = runtime + return runtime + return None + + def __setitem__(self, path, runtime) -> None: + self.cache[path] = runtime diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/template.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/template.py new file mode 100644 index 00000000..b917dec0 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit/template.py @@ -0,0 +1,93 @@ +import copy +import ctypes +import os +import torch + +from typing import Any, Iterable, Dict, Tuple + + +# Name map for Python `eval` +typename_map: Dict[Any, str] = { + **{t: t.__name__ for t in (bool, int, float)}, + torch.int: 'torch.int', + torch.float: 'torch.float', + torch.bfloat16: 'torch.bfloat16', + torch.float8_e4m3fn: 'torch.float8_e4m3fn', + torch.cuda.Stream: 'torch.cuda.Stream', +} + +# `ctype` map for Python casting +ctype_map: Dict[Any, Any] = { + **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)}, + **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)}, +} + + +# Type map for both Python API and source code usages +genc_map = { + bool: ('bool', 'bool'), + int: ('int', 'int'), + float: ('float', 'float'), + torch.int: ('void*', 'int*'), + torch.float: ('void*', 'float*'), + torch.bfloat16: ('void*', '__nv_bfloat16*'), + torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), + torch.cuda.Stream: ('void*', 'cudaStream_t'), +} + + +def map_ctype(value: Any) -> Any: + ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)] + if isinstance(value, torch.Tensor): + return ctype(value.data_ptr()) + if isinstance(value, torch.cuda.Stream): + return ctype(value.cuda_stream) + return ctype(value) + + +def cpp_format(template: str, keys: Dict[str, Any]) -> str: + # We don't use `str.format` because it's not safe for C++ {} braces + new_template = copy.deepcopy(template) + for key, value in keys.items(): + new_template = new_template.replace(f'{{{key}}}', f'{value}') + return new_template + + +def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str: + # Common prefix + code = '// DeepGEMM auto-generated JIT CUDA source file\n\n' + + # Includes + preload_sys_includes = ['', '', '', ''] + preload_package_includes = ['"cutlass/cutlass.h"'] + + assert isinstance(includes, list) or isinstance(includes, tuple) + sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')]))) + package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')]))) + code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n' + code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n' + + # Function signature + raw = '__raw_' + get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n + code += f'extern "C" void launch(' + code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ]) + code += ') {\n' + + # Cast raw types + code += ' // Cast raw types (if needed)\n' + for arg_name, arg_type in arg_defs: + if genc_map[arg_type][0] != genc_map[arg_type][1]: + code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n' + + # Function body + code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')]) + + # End the function + code += '}\n\n' + + # Debug print + if os.getenv('DG_JIT_DEBUG', None): + print(f'Generated code:\n{code}') + + return code diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/__init__.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/__init__.py new file mode 100644 index 00000000..d4c9aba7 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/__init__.py @@ -0,0 +1,10 @@ +from .gemm import gemm_fp8_fp8_bf16_nt +from .m_grouped_gemm import ( + m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, + m_grouped_gemm_fp8_fp8_bf16_nt_masked +) +from .utils import ( + cell_div, set_num_sms, get_num_sms, + get_col_major_tma_aligned_tensor, + get_m_alignment_for_contiguous_layout +) diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/gemm.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/gemm.py new file mode 100644 index 00000000..0251b8c4 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/gemm.py @@ -0,0 +1,171 @@ +import torch +from typing import Tuple + +from .tuner import jit_tuner +from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout + +# C++ code templates +includes = ('"deep_gemm/fp8_gemm.cuh"', ) +template = """ +using namespace deep_gemm; + +// Templated args from Python JIT call +constexpr auto N = {N}, K = {K}; +constexpr auto BLOCK_M = {BLOCK_M}; +constexpr auto BLOCK_N = {BLOCK_N}; +constexpr auto kNumStages = {NUM_STAGES}; +constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; + +// Make a templated GEMM +using GemmType = Gemm; + +// Launch kernel +auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); +auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); +auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); +auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); +GemmType::run(out, rhs_scales, nullptr, + m, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, + stream, num_sms, smem_size); +""" + + +def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool: + if num_tma_multicast == 1: + return True + return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 + + +def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: + smem_d = block_m * block_n * 2 + smem_a_per_stage = block_m * block_k + smem_scales_a_per_stage = block_m * 4 + smem_b_per_stage = block_n * block_k + smem_scales_b = cell_div(k, block_k) * 4 + smem_barrier = num_stages * 8 * 2 + + smem_size = 0 + smem_size += smem_d + smem_size += num_stages * smem_a_per_stage + smem_size += num_stages * smem_scales_a_per_stage + smem_size += num_stages * smem_b_per_stage + smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2) + smem_size += smem_barrier + return smem_size + + +def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, + is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]: + if not is_grouped_contiguous: + # TODO: for some cases, smaller M block is better, add them into tuning space + block_ms = (64 if m <= 64 else 128, ) + else: + block_ms = (get_m_alignment_for_contiguous_layout(), ) + block_ns = tuple(range(16, 129, 8)) + + fix_wave_saturate = lambda x: num_sms if x == 0 else x + get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None) + get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms) + + # Decide block sizes by waves + best_block_m, best_block_n = None, None + for block_m in block_ms: + for block_n in block_ns: + success = False + num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) + if best_block_m is None or best_block_n is None: + success = True + elif num_waves < best_num_waves: + success = True + elif num_waves == best_num_waves: + # Check last wave utilization + util = get_last_wave_util(block_m, block_n) + best_util = get_last_wave_util(best_block_m, best_block_n) + success = util > best_util or (util == best_util and (block_n >= best_block_n and block_m <= best_block_m)) + best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) + assert best_block_m is not None and best_block_n is not None + + # Always pick the longest one + # NOTES: for double B scales, the best number of stages may be reduced + best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 + for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) + if best_smem_size <= sm90_capacity: + best_num_stages = num_stages + break + assert best_num_stages is not None + + # Decide the number of TMA multicast + best_num_tma_multicast = 1 + if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: + best_num_tma_multicast = 2 + + return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size + + +def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor) -> None: + """ + Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. + LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. + RHS and RHS scaling factors are required to be transposed. + The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + this function will do a transposing with a set of slow PyTorch operations. + + Arguments: + lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`. + the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. + out: the BF16 output tensor of shape `[m, n]`, representing the result. + """ + lhs, lhs_scales = lhs + rhs, rhs_scales = rhs + m, k = lhs.shape + n, k_ = rhs.shape + m_, n_ = out.shape + + assert n % 64 == 0 and k % 128 == 0 + + # Type and shape checks + assert m == m_ and n == n_ and k == k_ + assert n > 0 and k > 0 + assert lhs_scales.shape == (m, (k + 127) // 128) + assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) + assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 + assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 + assert out.dtype == torch.bfloat16 + assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) + assert rhs_scales.is_contiguous() + + # Do nothing if `m` is zero + if m == 0: + return + + # Auto-tuning with compilation + global includes, template + num_sms = get_num_sms() + block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms) + args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) + runtime = jit_tuner.compile_and_tune( + name='gemm_fp8_fp8_bf16_nt', + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, + 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast}, + space=(), + includes=includes, + arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), + ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), + ('out', torch.bfloat16), ('m', int), + ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), + template=template, + args=args + ) + + # Run the kernel + runtime(*args) diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/m_grouped_gemm.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/m_grouped_gemm.py new file mode 100644 index 00000000..6d6e39b3 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -0,0 +1,182 @@ +import torch +from typing import Tuple + +from .gemm import get_best_configs +from .tuner import jit_tuner +from .utils import get_col_major_tma_aligned_tensor, get_num_sms + +# C++ code templates +includes = ('"deep_gemm/fp8_gemm.cuh"', ) +template = """ +using namespace deep_gemm; + +// Templated args from Python JIT call +constexpr auto N = {N}, K = {K}; +constexpr auto BLOCK_M = {BLOCK_M}; +constexpr auto BLOCK_N = {BLOCK_N}; +constexpr auto kNumStages = {NUM_STAGES}; +constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; + +// Make a templated grouped GEMM +using GemmType = Gemm; + +// Launch kernel +auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); +auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); +auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); +auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); +GemmType::run(out, rhs_scales, grouped_layout, + m, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, + stream, num_sms, smem_size); +""" + + +def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, m_indices: torch.Tensor) -> None: + """ + Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. + LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. + RHS and RHS scaling factors are required to be transposed. + The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + this function will do a transposing with a set of slow PyTorch operations. + On the M axis, inputs are grouped into several batches, of which batch sizes aligned to + `get_m_alignment_for_contiguous_layout()` (128). + + Arguments: + lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. + the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. + out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. + m_indices: a tensor of shape `[m_sum]` with type `torch.int`. + `m_indices[i]` records the group which the j-th row of the LHS belong to, + which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. + Values of `m_indices` in every-m-alignment-block must also be the same. + `-1` in this tensor indicates no RHS matrix selected, the kernel will skip the computation for that aligned block. + """ + lhs, lhs_scales = lhs + rhs, rhs_scales = rhs + m, k = lhs.shape + num_groups, n, k_ = rhs.shape + m_, n_ = out.shape + m__ = m_indices.numel() + + # Type and shape checks + assert m == m_ == m__ and k == k_ and n == n_ + assert lhs_scales.shape == (m, (k + 127) // 128) + assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) + assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 + assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 + assert out.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 + assert lhs.is_contiguous() and rhs.is_contiguous() + assert out.is_contiguous() and m_indices.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) + assert rhs_scales.is_contiguous() + + # Do nothing if `m` is zero + if m == 0: + return + + # Auto-tuning with compilation + global includes, template + num_sms = get_num_sms() + block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms, + is_grouped_contiguous=True) + args = (lhs, lhs_scales, rhs, rhs_scales, out, + m_indices, m, num_groups, + torch.cuda.current_stream(), num_sms, smem_size) + runtime = jit_tuner.compile_and_tune( + name='m_grouped_gemm_fp8_fp8_bf16_nt', + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'}, + space=(), + includes=includes, + arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), + ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), + ('out', torch.bfloat16), + ('grouped_layout', torch.int32), ('m', int), ('num_groups', int), + ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), + template=template, + args=args + ) + + # Run the kernel + runtime(*args) + + +def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: + """ + Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. + LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. + RHS and RHS scaling factors are required to be transposed. + The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + this function will do a transposing with a set of slow PyTorch operations. + Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch + should be separately transposed. + + Arguments: + lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. + the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. + out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. + masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute + in the i-th group. + expected_m: a value hint (which is a value on CPU) for the M expectation of each batch, + correctly setting this value may lead to better performance. + """ + lhs, lhs_scales = lhs + rhs, rhs_scales = rhs + num_groups, m, k = lhs.shape + num_groups_, n, k_ = rhs.shape + num_groups__, m_, n_ = out.shape + num_groups___ = masked_m.numel() + + # Type and shape checks + assert num_groups == num_groups_ == num_groups__ == num_groups___ + assert m == m_ and n == n_ and k == k_ + assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 + assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) + assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) + assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 + assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 + assert out.dtype == torch.bfloat16 + assert masked_m.dtype == torch.int32 + assert lhs.is_contiguous() and rhs.is_contiguous() + assert out.is_contiguous() and masked_m.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) + assert rhs_scales.is_contiguous() + + # Auto-tuning with compilation + global includes, template + num_sms = get_num_sms() + block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms) + args = (lhs, lhs_scales, rhs, rhs_scales, out, + masked_m, m, + torch.cuda.current_stream(), num_sms, smem_size) + runtime = jit_tuner.compile_and_tune( + name='m_grouped_gemm_fp8_fp8_bf16_nt', + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'}, + space=(), + includes=includes, + arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), + ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), + ('out', torch.bfloat16), + ('grouped_layout', torch.int32), ('m', int), + ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), + template=template, + args=args + ) + + # Run the kernel + runtime(*args) diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/tuner.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/tuner.py new file mode 100644 index 00000000..6ed67499 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/tuner.py @@ -0,0 +1,81 @@ +import copy +import os +import torch +from typing import Any, Dict + +from ..jit import build, cpp_format, generate, Runtime + + +class JITTuner: + def __init__(self) -> None: + self.tuned = {} + + def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, + includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime: + # NOTES: we always assume the space and template will not change + # We also assume the GPU device will not be changed + # NOTES: the function must have no accumulated side effects + keys = {k: keys[k] for k in sorted(keys.keys())} + signature = (name, f'{keys}') + if signature in self.tuned: + if os.getenv('DG_JIT_DEBUG', None): + print(f'Using cached JIT kernel {name} with keys {keys}') + return self.tuned[signature] + + if os.getenv('DG_JIT_DEBUG', None): + print(f'Auto-tuning JIT kernel {name} with keys {keys}') + + assert signature not in self.tuned + assert args is not None + space = (dict(), ) if len(space) == 0 else space + + kernels = [] + for tuned_keys in space: + assert isinstance(tuned_keys, dict) + full_keys = copy.deepcopy(keys) + full_keys.update(tuned_keys) + code = generate(includes, arg_defs, cpp_format(template, full_keys)) + + # Illegal build must raise errors + kernels.append((build(name, arg_defs, code), tuned_keys)) + + best_runtime, best_time, best_keys = None, None, None + for runtime, tuned_keys in kernels: + if len(space) > 1: + # Check kernel validity + return_code = runtime(*args) + if return_code != 0: + # Pass illegal kernels, e.g. insufficient shared memory capacity + if os.getenv('DG_JIT_DEBUG', None): + print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') + continue + + # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() + torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda') + start_event.record() + for i in range(20): + assert runtime(*args) == 0 + end_event.record() + end_event.synchronize() + elapsed_time = start_event.elapsed_time(end_event) + else: + elapsed_time = 0 + + # Compare if better + if best_time is None or elapsed_time < best_time: + best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys + if os.getenv('DG_JIT_DEBUG', None): + print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') + assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' + + # Cache the best runtime and return + if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None): + print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') + self.tuned[signature] = best_runtime + return best_runtime + + +jit_tuner = JITTuner() diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/utils.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/utils.py new file mode 100644 index 00000000..8ae50c92 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/jit_kernels/utils.py @@ -0,0 +1,105 @@ +import torch + +_num_sms = None + + +def set_num_sms(num_sms: int) -> None: + """ + Set the maximum SM count for all GEMM kernels to use. + + Arguments: + num_sms: the desired maximum SM count for all GEMM kernels to use. + """ + global _num_sms + assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count + _num_sms = num_sms + + +def get_num_sms() -> int: + """ + Get the current maximum limit of SM count for all GEMM kernels to use. + If the count is never specified, the function will return the number of device SMs. + + Returns: + Current maximum limit of SM count for all GEMM kernels to use. + """ + global _num_sms + if _num_sms is None: + _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count + return _num_sms + + +def cell_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. + + Returns: + The result of the ceiling division. + """ + return (x + y - 1) // y + + +def get_m_alignment_for_contiguous_layout(): + """ + When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. + Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well + with GEMM block shape. + + Returns: + Group-level alignment requirement for grouped contiguous layout, which is always 128. + """ + return 128 + + +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return cell_div(x, alignment) * alignment + + +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. + If the input tensor is already column-major layout and 16-byte aligned along the M axis + (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. + + Arguments: + x: usually the LHS scaling tensor in GEMM. + + Returns: + The LHS scaling tensor of TMA-aligned transposed format. + """ + # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + assert x.dim() in (2, 3) + remove_dim = False + if x.dim() == 2: + x, remove_dim = x.unsqueeze(0), True + + b, m, n = x.shape + aligned_m = get_tma_aligned_size(m, x.element_size()) + + # The last kernel gives a column-major TMA aligned layout + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x[:, :m, :] = x + return aligned_x.squeeze(0) if remove_dim else aligned_x diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/utils.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/utils.py new file mode 100644 index 00000000..b6bd7fb6 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/deep_gemm/utils.py @@ -0,0 +1,154 @@ +import os +import sys +import torch +import torch.distributed as dist + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, + trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False): + # Conflict with Nsight Systems + using_nsys = os.environ.get('DG_NSYS_PROFILING', False) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() + with profiler: + for i in range(2): + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + if barrier_comm_profiling: + lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + lhs @ rhs + dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) + for _ in range(num_tests): + if flush_l2: + torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() + fn() + + if not using_nsys: + profiler.step() + + # Return 1 if using Nsight Systems + if using_nsys: + return 1 + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tupled = isinstance(kernel_names, tuple) + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + for unit, scale in units.items(): + if unit in time_str: + kernel_times.append(float(time_str.replace(unit, '')) / scale) + break + break + return tuple(kernel_times) if is_tupled else kernel_times[0] + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(tensors): + total = 0 + for t in tensors: + if isinstance(t, tuple): + total += count_bytes(t) + else: + total += t.numel() * t.element_size() + return total diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/setup.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/setup.py new file mode 100644 index 00000000..c674dad3 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/setup.py @@ -0,0 +1,69 @@ +import os +import setuptools +import shutil +import subprocess +from setuptools.command.develop import develop +from setuptools.command.install import install + +current_dir = os.path.dirname(os.path.realpath(__file__)) +jit_include_dirs = ('deep_gemm/include/deep_gemm', ) +cutlass_dirs = '../../include' +third_party_include_dirs = (os.path.join(cutlass_dirs, 'cute'), os.path.join(cutlass_dirs, 'cutlass')) +print(third_party_include_dirs) + + +class PostDevelopCommand(develop): + def run(self): + develop.run(self) + self.make_jit_include_symlinks() + + @staticmethod + def make_jit_include_symlinks(): + # Make symbolic links of third-party include directories + for d in third_party_include_dirs: + dirname = d.split('/')[-1] + src_dir = f'{current_dir}/{d}' + dst_dir = f'{current_dir}/deep_gemm/include/{dirname}' + if not os.path.exists(src_dir): + os.makedirs(src_dir, exist_ok=True) + assert os.path.exists(src_dir) + if os.path.exists(dst_dir): + assert os.path.islink(dst_dir) + os.unlink(dst_dir) + os.symlink(src_dir, dst_dir, target_is_directory=True) + + +class PostInstallCommand(install): + def run(self): + install.run(self) + self.copy_jit_includes() + + def copy_jit_includes(self): + # Copy include directories needed by JIT + shutil.rmtree(f'{self.build_lib}/deep_gemm/include', ignore_errors=True) + os.makedirs(f'{self.build_lib}/deep_gemm/include', exist_ok=False) + for d in jit_include_dirs + third_party_include_dirs: + src_dir = f'{current_dir}/{d}' + dst_dir = f'{self.build_lib}/deep_gemm/include/{d.split("/")[-1]}' + assert os.path.exists(src_dir) + shutil.copytree(src_dir, dst_dir) + + +if __name__ == '__main__': + # noinspection PyBroadException + try: + cmd = ['git', 'rev-parse', '--short', 'HEAD'] + revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() + except: + revision = '' + + # noinspection PyTypeChecker + setuptools.setup( + name='deep_gemm', + version='1.0.0' + revision, + packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'], + cmdclass={ + 'develop': PostDevelopCommand, + 'install': PostInstallCommand + } + ) diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/tests/test_core.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/tests/test_core.py new file mode 100644 index 00000000..b4309037 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/tests/test_core.py @@ -0,0 +1,158 @@ +import random +import torch +from typing import Tuple + +import deep_gemm +from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +def construct(m: int, k: int, n: int) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_out = x @ y.t() + + x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out, ref_out + + +def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16) + ref_out = torch.einsum('gmk,gnk->gmn', x, y) + + assert m % 4 == 0, f'TMA alignment error: {m}' + x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + # For non-masked input, we must merge the group and M dims + if not is_masked: + x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) + out, ref_out = out.view(-1, n), ref_out.view(-1, n) + + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out, ref_out + + +def test_gemm() -> None: + print('Testing GEMM:') + for m in (64, 128, 4096): + for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: + x_fp8, y_fp8, out, ref_out = construct(m, k, n) + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) + diff = calc_diff(out, ref_out) + assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + # Construct new tensors every time to avoid L2 cache acceleration + x_fp8, y_fp8, out, ref_out = construct(m, k, n) + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing grouped contiguous GEMM:') + + for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)): + # TODO: make a stronger test + x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) + m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) + diff = calc_diff(out, ref_out) + assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + # Construct new tensors every time to avoid L2 cache acceleration + x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) + m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing grouped masked GEMM:') + + for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for k, n in ((7168, 4096), (2048, 7168), ): + # Test correctness + masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) + for i in range(10): + x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = random.choice(masked_m_candidates) + expected_m = int(masked_m.float().mean()) + 1 + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m) + for j in range(num_groups): + diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) + assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + # Construct new tensors every time to avoid L2 cache acceleration + x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) + masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m) + + # Test performance with fixed shapes + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/tests/test_jit.py b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/tests/test_jit.py new file mode 100644 index 00000000..78bc77b3 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/tests/test_jit.py @@ -0,0 +1,64 @@ +import os +import torch +from typing import Any + +from deep_gemm import jit + + +class Capture: + def __init__(self) -> None: + self.read_fd = None + self.write_fd = None + self.saved_stdout = None + self.captured = None + + def __enter__(self) -> Any: + self.read_fd, self.write_fd = os.pipe() + self.saved_stdout = os.dup(1) + os.dup2(self.write_fd, 1) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + os.dup2(self.saved_stdout, 1) + os.close(self.write_fd) + with os.fdopen(self.read_fd, 'r') as f: + self.captured = f.read() + + def capture(self) -> str: + return self.captured + + +if __name__ == '__main__': + # Runtime + print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n') + + # Templates + print('Generated code:') + args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16), + ('enable_double_streams', bool), ('stream', torch.cuda.Stream)) + body = "\n" + body += 'std::cout << reinterpret_cast(lhs) << std::endl;\n' + body += 'std::cout << reinterpret_cast(rhs) << std::endl;\n' + body += 'std::cout << reinterpret_cast(scale) << std::endl;\n' + body += 'std::cout << reinterpret_cast(out) << std::endl;\n' + body += 'std::cout << enable_double_streams << std::endl;\n' + body += 'std::cout << reinterpret_cast(stream) << std::endl;\n' + code = jit.generate((), args, body) + print(code) + + # Build + print('Building ...') + func = jit.build('test_func', args, code) + + # Test correctness + print('Running ...') + fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda') + fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda') + bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda') + with Capture() as capture: + assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0 + output = capture.capture() + ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n' + assert output == ref_output, f'{output=}, {ref_output=}' + + print('JIT test passed')