fix (#2684)
This commit is contained in:
@ -347,7 +347,7 @@ struct LayoutAwareConvertImpl<
|
||||
// Specialization for INT8 -> BF16 with [3120] value order
|
||||
template <>
|
||||
struct LayoutAwareConvertImpl<
|
||||
cutlass::int8_t,
|
||||
int8_t,
|
||||
cutlass::bfloat16_t,
|
||||
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
|
||||
cute::Layout<_4>
|
||||
@ -362,9 +362,9 @@ struct LayoutAwareConvertImpl<
|
||||
cute::Layout<_4>
|
||||
>& dst) {
|
||||
|
||||
static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
|
||||
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
|
||||
cute::is_same_v<cutlass::bfloat16_t, typename EngineOut::value_type>);
|
||||
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
|
||||
using SrcArray = cutlass::Array<int8_t, 8>;
|
||||
using DstArray = cutlass::Array<cutlass::bfloat16_t, 8>;
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;
|
||||
|
||||
@ -402,7 +402,7 @@ struct LayoutAwareConvertImpl<
|
||||
// Specialization for INT8 -> FP16 with [3120] value order
|
||||
template <>
|
||||
struct LayoutAwareConvertImpl<
|
||||
cutlass::int8_t,
|
||||
int8_t,
|
||||
cutlass::half_t,
|
||||
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
|
||||
cute::Layout<_4>
|
||||
@ -417,9 +417,9 @@ struct LayoutAwareConvertImpl<
|
||||
cute::Layout<_4>
|
||||
>& dst) {
|
||||
|
||||
static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
|
||||
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
|
||||
cute::is_same_v<cutlass::half_t, typename EngineOut::value_type>);
|
||||
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
|
||||
using SrcArray = cutlass::Array<int8_t, 8>;
|
||||
using DstArray = cutlass::Array<cutlass::half_t, 8>;
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;
|
||||
|
||||
|
||||
@ -52,21 +52,21 @@ namespace cutlass::gemm::collective {
|
||||
namespace detail {
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
|
||||
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
|
||||
constexpr int
|
||||
compute_stage_count_or_override(StageCount<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
|
||||
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
|
||||
constexpr int
|
||||
compute_stage_count_or_override(cute::Int<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
|
||||
constexpr int
|
||||
compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_count) {
|
||||
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
|
||||
@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale.
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int carveout_bytes_, int alignment = 128>
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int alignment = 128, int carveout_bytes_>
|
||||
constexpr int
|
||||
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
|
||||
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
|
||||
@ -107,7 +107,14 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_>
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
|
||||
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages, int alignment = 128>
|
||||
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
|
||||
constexpr int
|
||||
compute_stage_count_or_override_single_affine_transformed_input(cute::Int<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
|
||||
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
|
||||
constexpr int
|
||||
compute_stage_count_or_override_single_affine_transformed_input(StageCount<stages> stage_count) {
|
||||
return stages;
|
||||
@ -124,7 +131,7 @@ constexpr int get_bits_for_possibly_void_element() {
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
|
||||
constexpr int
|
||||
compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout<carveout_bytes_> stage_count) {
|
||||
|
||||
@ -456,12 +463,12 @@ public:
|
||||
static constexpr int PipelineStages = IsMixedInput ?
|
||||
( IsArrayOfPointersGemm ?
|
||||
detail::compute_stage_count_or_override_single_affine_transformed_input<Sm90ReducedSmemCapacityBytes,
|
||||
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{}) :
|
||||
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) :
|
||||
detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
|
||||
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{})
|
||||
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{})
|
||||
)
|
||||
: detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
ElementAMma, ElementBMma, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{});
|
||||
ElementAMma, ElementBMma, TileShape_MNK, SmemAlignment>(StageCountType{});
|
||||
|
||||
using DispatchPolicy = cute::conditional_t<IsMixedInput,
|
||||
cute::conditional_t<IsArrayOfPointersGemm,
|
||||
|
||||
@ -42,6 +42,7 @@
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cute/util/type_traits.hpp"
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -177,10 +178,7 @@ static void dequantize(DequantizedElement* dq_buffer,
|
||||
template <typename T>
|
||||
class packed_scale_t {
|
||||
public:
|
||||
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
|
||||
cute::is_same_v<T, cutlass::uint8_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e5m2_t>,
|
||||
static_assert(cute::sizeof_bits_v<T> == 8,
|
||||
"only 8 bit arithmetic types are supported.");
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit packed_scale_t(T val) {
|
||||
|
||||
Reference in New Issue
Block a user