[Build][Kernel] Update CUTLASS to v3.6.0 (#11607)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith
2024-12-30 04:22:13 -05:00
committed by GitHub
parent 628ec6c17b
commit 970d6d0776
6 changed files with 25 additions and 31 deletions

View File

@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
{% for sch in schs %}
@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
}
{%- endfor %}
@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
}; // namespace machete
"""
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK",
}

View File

@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
cute::enable_if_t<(
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
KernelTmaWarpSpecializedCooperative>)>> {
using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
};
}; // namespace cutlass::gemm::collective
}; // namespace cutlass::gemm::collective

View File

@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
using Schedule = KernelScheduleType;
static_assert(
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
"KernelSchedule must be one of the warp specialized policies");
public:
@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>,
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(

View File

@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
}
};
}; // namespace machete
}; // namespace machete