[Build][Kernel] Update CUTLASS to v3.6.0 (#11607)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
628ec6c17b
commit
970d6d0776
@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
||||
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
||||
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
||||
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Sch>;
|
||||
|
||||
{% for sch in schs %}
|
||||
@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
{{DataTypeTag[t.convert]}}, // ElementConvert
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
|
||||
>(args.B);
|
||||
}
|
||||
{%- endfor %}
|
||||
@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
# mostly unique shorter sch_sig
|
||||
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
kernel_terse_names_replace = {
|
||||
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
|
||||
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
|
||||
"TmaWarpSpecializedCooperative_": "TmaCoop_",
|
||||
"StreamKScheduler": "streamK",
|
||||
}
|
||||
|
||||
@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
|
||||
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<(
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedMixedInput> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedPingpongMixedInput> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
|
||||
KernelTmaWarpSpecializedCooperative>)>> {
|
||||
using CollectiveOp = machete::MacheteCollectiveMma<
|
||||
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
|
||||
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
|
||||
StageCountType, KernelScheduleType>;
|
||||
};
|
||||
|
||||
}; // namespace cutlass::gemm::collective
|
||||
}; // namespace cutlass::gemm::collective
|
||||
|
||||
@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
|
||||
using Schedule = KernelScheduleType;
|
||||
static_assert(
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<Schedule,
|
||||
KernelTmaWarpSpecializedPingpongMixedInput> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<Schedule,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
|
||||
public:
|
||||
@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
|
||||
@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelSchedule,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
}; // namespace machete
|
||||
|
||||
Reference in New Issue
Block a user