This commit is contained in:
Lain
2025-10-15 11:46:38 -07:00
committed by GitHub
parent c6aeb9179c
commit e6e2cc29f5
3 changed files with 24 additions and 19 deletions

View File

@ -347,7 +347,7 @@ struct LayoutAwareConvertImpl<
// Specialization for INT8 -> BF16 with [3120] value order
template <>
struct LayoutAwareConvertImpl<
cutlass::int8_t,
int8_t,
cutlass::bfloat16_t,
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
cute::Layout<_4>
@ -362,9 +362,9 @@ struct LayoutAwareConvertImpl<
cute::Layout<_4>
>& dst) {
static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
cute::is_same_v<cutlass::bfloat16_t, typename EngineOut::value_type>);
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
using SrcArray = cutlass::Array<int8_t, 8>;
using DstArray = cutlass::Array<cutlass::bfloat16_t, 8>;
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;
@ -402,7 +402,7 @@ struct LayoutAwareConvertImpl<
// Specialization for INT8 -> FP16 with [3120] value order
template <>
struct LayoutAwareConvertImpl<
cutlass::int8_t,
int8_t,
cutlass::half_t,
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
cute::Layout<_4>
@ -417,9 +417,9 @@ struct LayoutAwareConvertImpl<
cute::Layout<_4>
>& dst) {
static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
cute::is_same_v<cutlass::half_t, typename EngineOut::value_type>);
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
using SrcArray = cutlass::Array<int8_t, 8>;
using DstArray = cutlass::Array<cutlass::half_t, 8>;
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;

View File

@ -52,21 +52,21 @@ namespace cutlass::gemm::collective {
namespace detail {
// Manual-override overload: returns the stage count carried by the StageCount<stages> tag unchanged.
// No shared-memory sizing happens here: capacity_bytes, ElementA, ElementB, TileShapeMNK and alignment
// are not read, and the stage_count argument contributes only its type (its value is never used).
// NOTE(review): the two template headers below look like the before/after lines of a rendered diff
// (parameter order `stages, alignment` vs `alignment, stages`) — confirm which one is current.
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override(StageCount<stages> stage_count) {
return stages;
}
// Manual-override overload taking a cute::Int<stages> tag: returns `stages` unchanged.
// No shared-memory sizing happens here: capacity_bytes, ElementA, ElementB, TileShapeMNK and alignment
// are not read, and the stage_count argument contributes only its type (its value is never used).
// NOTE(review): the two template headers below look like the before/after lines of a rendered diff
// (parameter order `stages, alignment` vs `alignment, stages`) — confirm which one is current.
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override(cute::Int<stages> stage_count) {
return stages;
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
constexpr int
compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale.
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int carveout_bytes_, int alignment = 128>
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int alignment = 128, int carveout_bytes_>
constexpr int
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
@ -107,7 +107,14 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_>
}
// Manual-override overload taking a cute::Int<stages> tag: returns `stages` unchanged.
// Nothing is computed from the capacity, element types (including ElementScale / ElementZero),
// tile shape or alignment in this overload; the stage_count argument contributes only its type.
// NOTE(review): the two template headers below look like the before/after lines of a rendered diff
// (parameter order `stages, alignment` vs `alignment, stages`) — confirm which one is current.
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages, int alignment = 128>
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(cute::Int<stages> stage_count) {
return stages;
}
// Manual-override overload taking a StageCount<stages> tag: returns the explicitly requested stage count; no smem-capacity computation is performed in this overload.
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCount<stages> stage_count) {
return stages;
@ -124,7 +131,7 @@ constexpr int get_bits_for_possibly_void_element() {
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout<carveout_bytes_> stage_count) {
@ -456,12 +463,12 @@ public:
// Stage-count selection, keyed on the kind of GEMM being built:
//   - IsMixedInput: compute_stage_count_or_override_single_affine_transformed_input, with
//     Sm90ReducedSmemCapacityBytes as the capacity when IsArrayOfPointersGemm and
//     detail::sm90_smem_capacity_bytes otherwise;
//   - otherwise: compute_stage_count_or_override with detail::sm90_smem_capacity_bytes.
// StageCountType{} is passed as a tag argument in every branch.
// NOTE(review): each call below shows two template-argument lines (with and without
// `StageCountType::bytes`), which look like the before/after lines of a rendered diff — confirm.
static constexpr int PipelineStages = IsMixedInput ?
( IsArrayOfPointersGemm ?
detail::compute_stage_count_or_override_single_affine_transformed_input<Sm90ReducedSmemCapacityBytes,
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{}) :
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) :
detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{})
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{})
)
: detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
ElementAMma, ElementBMma, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{});
ElementAMma, ElementBMma, TileShape_MNK, SmemAlignment>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsMixedInput,
cute::conditional_t<IsArrayOfPointersGemm,

View File

@ -42,6 +42,7 @@
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cute/util/type_traits.hpp"
#include "cute/numeric/numeric_types.hpp"
namespace cutlass {
@ -177,10 +178,7 @@ static void dequantize(DequantizedElement* dq_buffer,
template <typename T>
class packed_scale_t {
public:
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
cute::is_same_v<T, cutlass::uint8_t> ||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
cute::is_same_v<T, cutlass::float_e5m2_t>,
static_assert(cute::sizeof_bits_v<T> == 8,
"only 8 bit arithmetic types are supported.");
CUTLASS_HOST_DEVICE
explicit packed_scale_t(T val) {