Add cutlass support for blackwell fp8 blockwise gemm (#14383)
Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
|
||||
@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm100_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -0,0 +1,27 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(
|
||||
a.size(0) % 4 == 0,
|
||||
"Input tensor must have a number of rows that is a multiple of 4. ",
|
||||
"but got: ", a.size(0), " rows.");
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,205 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
|
||||
class ClusterShape, typename EpilogueScheduler,
|
||||
typename MainloopScheduler>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = ElementAB;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using ElementD = OutType;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using LayoutC = LayoutD;
|
||||
static constexpr int AlignmentC = AlignmentD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
|
||||
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
|
||||
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
|
||||
static constexpr int ScaleGranularityM =
|
||||
size<0>(MmaTileShape{}) / ScaleMsPerTile;
|
||||
static constexpr int ScaleGranularityN =
|
||||
size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
|
||||
static constexpr int ScaleGranularityK =
|
||||
size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
|
||||
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = ClusterShape;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using ElementScalar = float;
|
||||
// clang-format off
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueScheduler,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutSFA = typename Gemm::LayoutSFA;
|
||||
using LayoutSFB = typename Gemm::LayoutSFB;
|
||||
using ScaleConfig = typename Gemm::ScaleConfig;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||
|
||||
StrideA a_stride;
|
||||
StrideB b_stride;
|
||||
StrideC c_stride;
|
||||
a_stride =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
|
||||
LayoutSFA layout_SFA =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_SFB =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, a_stride, b_ptr, b_stride,
|
||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto m = a.size(0);
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
|
||||
return std::ceil(static_cast<float>(m) / tile1SM) *
|
||||
std::ceil(static_cast<float>(n) / tile1SM) >=
|
||||
sms;
|
||||
};
|
||||
bool use_2sm = should_use_2sm(m, n);
|
||||
if (use_2sm) {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
|
||||
Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
57
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
Normal file
57
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
Normal file
@ -0,0 +1,57 @@
|
||||
#include <torch/all.h>
|
||||
#include "cuda_utils.h"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias,
|
||||
Fp8Func fp8_func, Int8Func int8_func,
|
||||
BlockwiseFunc blockwise_func) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
|
||||
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
fp8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
|
||||
int8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Int8 not supported for this architecture");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128} &&
|
||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn),
|
||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
blockwise_func(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
} // namespace vllm
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm100 (Blackwell).
|
||||
@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm100_fp8,
|
||||
nullptr, // int8 not supported on SM100
|
||||
vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
|
||||
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
} else {
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128} &&
|
||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn),
|
||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
|
||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm90_fp8,
|
||||
vllm::cutlass_scaled_mm_sm90_int8,
|
||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
@ -110,6 +110,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
} else if (cuda_device_capability >= 100) {
|
||||
return CUDA_VERSION >= 12080;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
|
||||
return
|
||||
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % 4 != 0 and current_platform.has_device_capability(100):
|
||||
return
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
|
||||
|
||||
@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear(
|
||||
or br not in (1, weight.shape[0])):
|
||||
shape_supported_by_cutlass = False
|
||||
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
|
||||
rows, cols = input_2d.shape
|
||||
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
|
||||
# optimal tensor core usage. Can be removed when targeting platforms
|
||||
# without this constraint.
|
||||
should_pad = current_platform.has_device_capability(
|
||||
100) and rows % 4 != 0
|
||||
if should_pad:
|
||||
input_2d = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, 4 - (rows % 4)),
|
||||
value=0).contiguous()
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True)
|
||||
@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear(
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale.T)
|
||||
if should_pad:
|
||||
output = output[:rows, :]
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
|
||||
Reference in New Issue
Block a user