Add Blackwell MLA forward (shape: d=192, dv=128) implementation in example_77 (#2472)

Author: zhang
Date: 2025-07-18 13:27:48 +08:00
Committed by: GitHub
Parent: ebe98c549a
Commit: 9baa06dd57
13 changed files with 3323 additions and 40 deletions


@ -853,7 +853,7 @@ struct FwdRunner {
flops *= static_cast<double>(size<1>(problem_shape));
flops *= static_cast<double>(size<3,1>(problem_shape));
}
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask<true>> || std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
flops *= static_cast<double>(size<2>(problem_shape));
flops *= static_cast<double>(size<3,0>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
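The change above only widens the causal check to cover both CausalMask<true> and CausalMask<false>; the FLOP model itself is unchanged. A minimal host-side sketch of that model, with plain doubles standing in for the problem_shape modes (all names below are illustrative, not the runner's API):

#include <cstdio>

// Host-side model only: two GEMMs (Q*K^T and P*V) at 2 FLOPs per
// multiply-accumulate give roughly 4 * seq_q * seq_k * head_dim FLOPs per
// head, and a causal mask computes about half of the K tiles.
double attention_tflops_per_s(double runtime_ms, double seq_q, double seq_k,
                              double head_dim, double heads, double batch,
                              bool causal) {
  double flops = 4.0 * seq_q * seq_k * head_dim * heads * batch;
  if (causal) flops *= 0.5;
  return flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
}

int main() {
  std::printf("%.1f TFLOP/s\n",
              attention_tflops_per_s(1.0, 8192, 8192, 128, 16, 1, true));
}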

File diff suppressed because it is too large.


@ -33,6 +33,7 @@ set_property(
77_blackwell_fmha_gen.cu
77_blackwell_mla.cu
77_blackwell_fmha_bwd.cu
77_blackwell_mla_fwd.cu
PROPERTY
COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
)
@ -59,6 +60,22 @@ set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_02 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_03 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_MLA_FWD_VARLEN_04 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_MLA_FWD_VARLEN_05 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_MLA_FWD_VARLEN_06 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
set(TEST_MLA_FWD_VARLEN_07 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
set(TEST_MLA_FWD_VARLEN_08 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
set(TEST_MLA_FWD_VARLEN_09 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
set(TEST_MLA_FWD_VARLEN_10 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=2:3 --varlen-k=2:5)
set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=11:10 --varlen-k=13:10)
set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
@ -161,6 +178,35 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_mla_fwd_${PREC}
77_blackwell_mla_fwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_VARLEN
TEST_HDIM64
TEST_GQA
TEST_MLA_FWD_VARLEN_00
TEST_MLA_FWD_VARLEN_01
TEST_MLA_FWD_VARLEN_02
TEST_MLA_FWD_VARLEN_03
TEST_MLA_FWD_VARLEN_04
TEST_MLA_FWD_VARLEN_05
TEST_MLA_FWD_VARLEN_06
TEST_MLA_FWD_VARLEN_07
TEST_MLA_FWD_VARLEN_08
TEST_MLA_FWD_VARLEN_09
TEST_MLA_FWD_VARLEN_10
TEST_MLA_FWD_VARLEN_11
TEST_MLA_FWD_VARLEN_12
TEST_MLA_FWD_VARLEN_13
TEST_MLA_FWD_VARLEN_14
)
target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_mla_fwd_${PREC} PRIVATE -Xptxas -v)
endforeach()
# Add a target that builds all examples
@ -176,5 +222,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_mla_fwd_fp8
77_blackwell_mla_fwd_fp16
)
endif()


@ -184,10 +184,16 @@ struct ResidualMaskForBackward : NoMask {
}
};
// There are two ways to do causal if N_Q != N_K
// (1) The Q is at the beginning of the matrix
// (2) The Q is at the end of the matrix
template<bool kIsQBegin = true>
struct CausalMask : NoMask {
using Base = NoMask;
static constexpr bool IsQBegin = kIsQBegin;
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
@ -197,9 +203,16 @@ struct CausalMask : NoMask {
// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
if constexpr (IsQBegin) {
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
} else {
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
@ -208,9 +221,14 @@ struct CausalMask : NoMask {
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
if constexpr (IsQBegin) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape));
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
@ -232,26 +250,36 @@ struct CausalMask : NoMask {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is what we demonstrate here
// - this is the default setting.
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
acc_qk(i) = -INFINITY;
// - if you'd like this, you only need to set kIsQBegin=false
if constexpr (IsQBegin) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
} else {
const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
}
};
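With kIsQBegin=false the queries are aligned to the end of the key sequence, which only shifts the comparison by offset_q = N_K - N_Q. A scalar sketch of both conventions, with plain ints standing in for the CuTe coordinates (row = query index, col = key index); a true return here corresponds to the -INFINITY write above:

#include <cassert>

// Scalar sketch of the two causal conventions (not the kernel's CuTe code).
// An element is masked out when the predicate returns true.
bool causal_masked(int row, int col, int seq_q, int seq_k, bool q_begin) {
  if (col >= seq_k) return true;               // residual masking past the end
  int offset_q = q_begin ? 0 : (seq_k - seq_q);
  return row + offset_q < col;                 // query attends to keys at or before its position
}

int main() {
  // seq_q = 2, seq_k = 4: with q_begin the two queries see 1 and 2 keys;
  // with q_end (inference-style) they see 3 and 4 keys.
  assert( causal_masked(0, 1, 2, 4, /*q_begin=*/true));
  assert(!causal_masked(0, 2, 2, 4, /*q_begin=*/false));
  assert(!causal_masked(1, 3, 2, 4, /*q_begin=*/false));
}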
struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {
struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
using Base = CausalMask;
using Base = CausalMask<true>;
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE


@ -42,7 +42,8 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_ // Q, B
class StrideLSE_, // Q, B
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
@ -56,7 +57,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
using ElementOut = Element;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
struct TensorStorage {
using SmemLayoutO = SmemLayoutO_;
@ -86,6 +90,19 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
StrideLSE dLSE;
};
// FMHA and MLA have different input ProblemShapes;
// get problem_shape_O according to the input ProblemShape.
template<class ProblemShape>
CUTLASS_DEVICE static constexpr
auto get_problem_shape_O (
ProblemShape const& problem_shape) {
if constexpr (rank_v<decltype(get<2>(ProblemShape{}))> == 2) {
return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));
} else {
return select<0,2,3>(problem_shape);
}
}
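For intuition, a plain-struct sketch of what this helper selects in the MLA case, where the D mode of the problem shape has rank 2, (d_latent, d_rope) = (128, 64): QK runs at head dim 192 while O carries only the 128 latent dims. The struct names below are stand-ins, not the kernel's CuTe tuples:

#include <cstdio>

struct MlaProblemShape { int seq_q, seq_k, d_latent, d_rope, heads, batch; };
struct OShape          { int seq_q, d, heads, batch; };

// Mirrors replace<1>(select<0,2,3>(problem_shape), get<2,0>(problem_shape)).
OShape problem_shape_O(MlaProblemShape const& p) {
  return OShape{p.seq_q, p.d_latent, p.heads, p.batch};
}

int main() {
  OShape o = problem_shape_O({4096, 4096, 128, 64, 128, 1});
  std::printf("O shape: (%d, %d, (%d, %d))\n", o.seq_q, o.d, o.heads, o.batch);
}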
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
@ -94,7 +111,8 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
auto ptr_O = args.ptr_O;
StrideO dO = args.dO;
auto problem_shape_O = select<0,2,3>(problem_shape);
auto problem_shape_O = get_problem_shape_O(problem_shape);
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
@ -146,7 +164,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
int o0_index = 2 * get<0>(blk_coord);
int o1_index = 2 * get<0>(blk_coord) + 1;
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(select<0,2,3>(problem_shape));
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
// offset mode 0 by (max_length - real_length)
// offset mode 3,1 by cumulative_length + real_length
// the ptr is already offset by - max_length
@ -201,6 +219,11 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
tma_store_wait<0>();
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;


@ -58,7 +58,9 @@ template<
// and refers to the two softmax warps
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
// (1, 2, 1) means they sit side by side (best for small Q / large K)
class ThreadShape = Shape<_2, _1, _1>
class ThreadShape = Shape<_2, _1, _1>,
// For FMHA, shared memory is sufficient, so there is no need to reuse it.
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdMainloopTmaWarpspecialized {
@ -106,6 +108,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountKV>{}));
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountKV>{}));
// Reuse shared memory for V and O.
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
struct TensorStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
@ -168,9 +172,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadKV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>), "K and V smem layouts must be of equal size");
static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size");
using Load = Sm100FmhaLoadTmaWarpspecialized<
Element, StrideQ, StrideK, StrideV,
@ -525,7 +530,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
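The tilePlikeFP32 expression above (the fix is just get<1> to size<1>) views the N-column P tile through fp32-sized columns. A tiny arithmetic sketch, under an assumed N = 128 and a 16-bit element:

#include <cstdint>
#include <cstdio>

// N * sizeof(Element) / sizeof(float): an N-column tile of a 16-bit element
// packs into half as many fp32-sized columns. Values are illustrative.
int main() {
  constexpr int kTileN = 128;                            // size<1>(TileShapeQK{})
  constexpr int kElemBytes = int(sizeof(std::uint16_t)); // fp16 / bf16
  constexpr int kCols = kTileN / int(sizeof(float)) * kElemBytes;
  std::printf("P tile spans %d fp32 columns\n", kCols);  // prints 64
}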


@ -0,0 +1,340 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element,
class StrideQ,
class StrideK,
class StrideV,
class CollectiveMmaQK,
class CollectiveMmaPV,
class SmemLayoutQ,
class SmemLayoutK,
class SmemLayoutV,
class TensorStorage,
class PipelineQ,
class PipelineKV,
class Mask,
class TileShape,
class OrderLoadEpilogue = cute::false_type
>
struct Sm100MlaFwdLoadTmaWarpspecialized {
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TileShapePV = typename CollectiveMmaPV::TileShape;
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
struct Arguments {
const Element* ptr_Q;
StrideQ dQ;
const Element* ptr_K;
StrideK dK;
const Element* ptr_V;
StrideV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
};
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace) {
auto ptr_Q = args.ptr_Q;
auto ptr_K = args.ptr_K;
auto ptr_V = args.ptr_V;
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
// for variable sequence length, the batch stride is in units of row_stride
get<2,1>(dQ) = get<0>(dQ);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence length, the batch stride is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
}
}
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
problem_shape_qk,
typename CollectiveMmaQK::Arguments {
ptr_Q, dQ,
ptr_K, dK,
}, /*workspace=*/ nullptr);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
problem_shape_pv,
typename CollectiveMmaPV::Arguments {
ptr_K, dK, // never used, dummy
ptr_V, select<1,0,2>(dV),
}, /*workspace=*/ nullptr);
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
BlkCoord blk_coord_q = blk_coord_in;
BlkCoord blk_coord_kv = blk_coord_in;
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
using X = Underscore;
// this one is only executed by one thread, no need to elect_one
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
// two pipes: Q and KV
// from Memory (prod) to TensorCore (cons)
// compute gQ, sQ
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto [tQgQ_qdl, tQsQ] = tma_partition(
params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
);
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
// compute gK, sK
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length;
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
auto [tKgK_kdl, tKsK] = tma_partition(
params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
);
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
// compute gV, sV
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
auto [tVgV_dkl, tVsV] = tma_partition(
params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
);
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
// blk_coord is decomposed in terms of TileShape, not TileShapeQK
// As such, it needs to be transformed as
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
uint32_t lane_predicate = cute::elect_one_sync();
// Q1
int q0_index = 2 * get<0>(blk_coord_q);
int q1_index = 2 * get<0>(blk_coord_q) + 1;
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// K1
int k_index = 0;
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
// Q2
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
// V1
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
k_index += 1;
// loop:
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// Ki
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
// prefetch vi
cute::prefetch(params.tma_load_v, tVgV(_, k_index));
}
++pipeline_kv_producer_state;
// Vi
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
// prefetch ki+1
if(mask_tile_count > 1) {
cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));
}
}
++pipeline_kv_producer_state;
k_index += 1;
}
}
};
} // namespace cutlass::fmha::collective
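The load above interleaves K and V through one pipeline but arms the barrier with different transaction-byte counts via producer_acquire_bytes, because for this MLA shape the K tiles (d = 192) are larger than the V tiles (dv = 128). A rough sketch of that size gap, under an assumed tile size and element width:

#include <cstdio>

// With d = 192 (128 latent + 64 rope) for Q/K and dv = 128 for V, a K stage
// carries 1.5x the bytes of a V stage, so the shared KV pipeline is armed with
// TransactionBytesLoadV when the stage holds V. Sizes below are illustrative.
int main() {
  constexpr int kKeysPerTile  = 128;  // rows of K/V per pipeline stage (assumption)
  constexpr int kHeadDimQK    = 192;  // d = d_latent + d_rope
  constexpr int kHeadDimV     = 128;  // dv = d_latent
  constexpr int kBytesPerElem = 2;    // fp16/bf16 (fp8 would be 1)

  std::printf("K stage: %d bytes, V stage: %d bytes\n",
              kKeysPerTile * kHeadDimQK * kBytesPerElem,
              kKeysPerTile * kHeadDimV  * kBytesPerElem);
}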


@ -0,0 +1,250 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Allow the producer to acquire a specific number of bytes of data.
*/
#pragma once
#include "cutlass/pipeline/sm100_pipeline.hpp"
namespace cutlass {
using namespace cute;
template <
int Stages_,
class ClusterShape = Shape<int,int,_1>,
class AtomThrShape_MNK_ = Shape<_1,_1,_1>
>
class PipelineTmaAsyncMla {
public:
static constexpr uint32_t Stages = Stages_;
using AtomThrShape_MNK = AtomThrShape_MNK_;
private:
using Impl = PipelineTmaUmmaAsync<Stages_, ClusterShape, AtomThrShape_MNK_>;
public:
using FullBarrier = typename Impl::FullBarrier;
using EmptyBarrier = typename Impl::EmptyBarrier;
using ProducerBarrierType = typename Impl::ProducerBarrierType;
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
using PipelineState = typename Impl::PipelineState;
using SharedStorage = typename Impl::SharedStorage;
using ThreadCategory = typename Impl::ThreadCategory;
using Params = typename Impl::Params;
using McastDirection = McastDirection;
// Helper function to initialize barriers
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
auto atom_thr_shape = AtomThrShape_MNK{};
uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
(cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1;
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {
auto atom_thr_shape = AtomThrShape_MNK{};
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ?
cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas
cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
// Calculate consumer mask
if (params_.role == ThreadCategory::Consumer) {
auto cluster_layout = make_layout(cluster_shape);
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
// Calculate consumer mask
dim3 block_id_in_cluster = cute::block_id_in_cluster();
auto cluster_layout = make_layout(cluster_shape);
if (mcast_direction == McastDirection::kRow) {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
else {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
public:
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
, params_(params)
, empty_barrier_ptr_(&storage.empty_barrier_[0])
, full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape);
}
}
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
, params_(params)
, empty_barrier_ptr_(&storage.empty_barrier_[0])
, full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape, mcast_direction);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape, mcast_direction);
}
}
CUTLASS_DEVICE
void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.producer_acquire(state, barrier_token);
}
CUTLASS_DEVICE
void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {
detail::pipeline_check_is_producer(params_.role);
if (barrier_token != BarrierStatus::WaitDone) {
empty_barrier_ptr_[stage].wait(phase);
}
if (params_.is_leader) {
full_barrier_ptr_[stage].arrive_and_expect_tx(bytes);
}
#ifndef NDEBUG
if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) {
asm volatile ("brkpt;\n" ::);
}
// Most likely you have elected more than one leader
if (params_.is_leader && (threadIdx.x % 32 != 0)) {
asm volatile ("brkpt;\n" ::);
}
#endif
}
CUTLASS_DEVICE
void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token);
}
CUTLASS_DEVICE
ProducerBarrierType* producer_get_barrier(PipelineState state) {
return impl_.producer_get_barrier(state);
}
CUTLASS_DEVICE
void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.consumer_wait(state, barrier_token);
}
CUTLASS_DEVICE
void consumer_release(PipelineState state) {
consumer_release(state.index(), false);
}
private:
Impl impl_;
Params params_;
EmptyBarrier *empty_barrier_ptr_;
FullBarrier *full_barrier_ptr_;
uint16_t block_id_mask_ = 0;
static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
// Consumer signalling Producer of completion
// Ensures all blocks in the same row and column get notified.
CUTLASS_DEVICE
void consumer_release(uint32_t stage, uint32_t skip) {
detail::pipeline_check_is_consumer(params_.role);
uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
if (!skip) {
cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
}
}
else {
if (!skip) {
if constexpr (cute::is_static_v<ClusterShape> and size(ClusterShape{}) == 1) {
cutlass::arch::umma_arrive(smem_ptr);
}
else {
cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
}
}
}
}
};
}
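In effect, producer_acquire_bytes differs from producer_acquire only in that the full barrier is armed with a caller-supplied byte count rather than the pipeline's fixed params.transaction_bytes. A host-side model of the expect-tx semantics, with atomics standing in for the mbarrier (a conceptual sketch, not the hardware path; names and sizes are illustrative):

#include <atomic>
#include <cstdio>
#include <thread>

// The producer arms the barrier with the bytes it is about to transfer, the
// (simulated) TMA completion reports the bytes that landed, and the consumer
// is released once the expected count is met.
struct TxBarrier {
  std::atomic<int> expected{0}, arrived{0};
  void arrive_and_expect_tx(int bytes) { expected.fetch_add(bytes); }
  void complete_tx(int bytes) { arrived.fetch_add(bytes); }
  void wait() const {
    while (expected.load() == 0 || arrived.load() < expected.load())
      std::this_thread::yield();
  }
};

int main() {
  TxBarrier full_barrier;
  constexpr int kBytesV = 128 * 128 * 2;  // smaller V stage (illustrative)

  std::thread producer([&] {
    full_barrier.arrive_and_expect_tx(kBytesV);  // producer_acquire_bytes(..., kBytesV)
    full_barrier.complete_tx(kBytesV);           // TMA copy lands in smem
  });
  std::thread consumer([&] {
    full_barrier.wait();                         // consumer_wait on the full barrier
    std::printf("V stage ready: %d bytes\n", kBytesV);
  });
  producer.join();
  consumer.join();
}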


@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
// Swizzle Q tile and H tile to improve L2 cache hit rate,
// and launch the longest main loop first to keep most SMs busy.
struct CausalIndividualTileScheduler {
static constexpr int TileQ = 16;
static constexpr int TileH = 8;
static constexpr int TileSize = TileQ * TileH;
struct Params {
dim3 grid;
int tile_max_q;
FastDivmod divmod_tile_col;
FastDivmod divmod_tile_size;
FastDivmod divmod_tile_head;
};
bool valid_ = true;
Params params;
CUTLASS_DEVICE
CausalIndividualTileScheduler(Params const& params) : params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));
// gridDim.x must be a multiple of TileH
const int tile_col_count = grid.x / TileH;
const int tile_max_q = grid.y / TileQ * TileQ;
return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH};
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;
int tile_idx, tile_tail;
params.divmod_tile_size(tile_idx, tile_tail, block_idx);
int tile_row_idx, tile_col_idx;
params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx);
int row_offset_in_tail, col_offset_in_tail;
params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail);
const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;
const int col_idx = tile_col_idx * TileH + col_offset_in_tail;
// the last Q tile launches first
if(blockIdx.y >= params.tile_max_q) {
return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));
}
return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));
}
CUTLASS_DEVICE
CausalIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
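A host-side replay of the index arithmetic above, with illustrative grid sizes (the real scheduler uses FastDivmod): consecutive CTAs are grouped into TileQ x TileH = 16 x 8 super-tiles so that CTAs resident at the same time share K/V heads in L2, and the Q-block index is reversed so the longest causal main loops launch first.

#include <cstdio>

int main() {
  constexpr int TileQ = 16, TileH = 8, TileSize = TileQ * TileH;
  const int grid_x = 16 /*heads*/, grid_y = 64 /*Q blocks*/;
  const int tile_col_count = grid_x / TileH;   // assumes grid_x % TileH == 0
  const int tile_max_q = grid_y / TileQ * TileQ;

  for (int by = 0; by < 2; ++by) {             // print the first few CTAs
    for (int bx = 0; bx < 4; ++bx) {
      int block_idx = by * grid_x + bx;
      int tile_idx = block_idx / TileSize, tile_tail = block_idx % TileSize;
      int tile_row = tile_idx / tile_col_count, tile_col = tile_idx % tile_col_count;
      int row = tile_row * TileQ + tile_tail / TileH;
      int col = tile_col * TileH + tile_tail % TileH;
      int q_block = (by >= tile_max_q) ? (grid_y - 1 - by) : (grid_y - 1 - row);
      int head    = (by >= tile_max_q) ? bx : col;
      std::printf("CTA(%2d,%2d) -> q_block %2d, head %2d\n", bx, by, q_block, head);
    }
  }
}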
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Launch order: H Q B
struct CausalPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_h;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
return Params {
num_blocks,
{ size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_h(block_decode, bidh, block_decode);
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE
CausalPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel


@ -1245,7 +1245,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
}
bool trailing_residual_masking = false;
@ -1682,7 +1683,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {


@ -28,6 +28,7 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
@ -38,6 +39,7 @@
#include "kernel/fmha_options.hpp"
#include "kernel/fmha_tile_scheduler.hpp"
#include "kernel/fmha_causal_tile_scheduler.hpp"
#include "collective/fmha_fusion.hpp"
#include "collective/fmha_common.hpp"
@ -79,6 +81,45 @@ struct Sm100FmhaCtxKernelWarpspecializedSchedule {
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsEmpty = 24;
static const int NumWarps = 16;
};
struct Sm100MlaFwdCtxKernelWarpspecializedSchedule {
enum class WarpRole {
Softmax0,
Softmax1,
Correction,
MMA,
Load,
Epilogue,
Empty
};
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
int wg_idx = warp_idx / 4; // warpgroup index
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
if (warp_idx == 12) return WarpRole::MMA; // 12
if (warp_idx == 13) return WarpRole::Load; // 13
if (warp_idx == 14) return WarpRole::Epilogue; // 14
return WarpRole::Empty; // 15
}
static const int NumWarpsSoftmax = 4;
static const int NumWarpsCorrection = 4;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
static const bool kDebugUsingPrintf = false;
static const int NumRegsSoftmax = 184;
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsEmpty = 24;
static const int NumWarps = 16;
@ -106,6 +147,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue);
static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad);
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
@ -114,13 +158,31 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
static const int NumWarps = KernelSchedule::NumWarps;
static constexpr bool IsMla = std::is_same_v<KernelSchedule, Sm100MlaFwdCtxKernelWarpspecializedSchedule>;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using TmemAllocator = cute::TMEM::Allocator1Sm;
struct SharedStorage {
typename CollectiveMainloop::TensorStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
using UnionType = union {
typename CollectiveMainloop::TensorStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
using StructType = struct {
typename CollectiveMainloop::TensorStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
static constexpr bool IsPersistent = std::is_same_v<TileScheduler, PersistentTileScheduler> || std::is_same_v<TileScheduler, CausalPersistentTileScheduler>;
using MainloopEpilogueStorage = std::conditional_t<IsPersistent,
std::conditional_t<IsMla,
std::conditional_t<CollectiveMainloop::IsOrderLoadEpilogue, UnionType, StructType>,
StructType>,
UnionType>;
MainloopEpilogueStorage mainloop_epilogue;
struct PipelineStorage {
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
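The storage selection a few lines above chooses a union when mainloop and epilogue shared memory may overlap (MLA with OrderLoadEpilogue, or a non-persistent scheduler) and a struct when both must coexist. A sketch with stand-in types showing the effect on SharedStorage size; the byte counts are made up for illustration:

#include <cstdio>
#include <type_traits>

struct MainloopSmem { char q[16 * 1024]; char kv[96 * 1024]; };
struct EpilogueSmem { char o[32 * 1024]; };

union  OverlappedStorage { MainloopSmem mainloop; EpilogueSmem epilogue; };
struct SeparateStorage   { MainloopSmem mainloop; EpilogueSmem epilogue; };

template <bool kOverlap>
using MainloopEpilogueStorage =
    std::conditional_t<kOverlap, OverlappedStorage, SeparateStorage>;

int main() {
  std::printf("overlapped: %zu KiB, separate: %zu KiB\n",
              sizeof(MainloopEpilogueStorage<true>) / 1024,
              sizeof(MainloopEpilogueStorage<false>) / 1024);
}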
@ -206,6 +268,16 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
auto get_epilogue_storage = [&]() {
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
return reinterpret_cast<typename CollectiveEpilogue::TensorStorage *>(shared_storage.mainloop_epilogue.mainloop.smem_o.data());
} else {
return &shared_storage.mainloop_epilogue.epilogue;
}
};
typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage();
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
if (role == WarpRole::Load) {
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
@ -228,7 +300,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
}
pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadKV;
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK;
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
shared_storage.pipelines.load_kv,
pipeline_load_kv_params,
@ -409,7 +481,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
shared_storage.epilogue,
epilogue_storage,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
@ -420,7 +492,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
shared_storage.epilogue,
epilogue_storage,
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
@ -462,7 +534,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
mainloop.mma(
blk_coord,
params.mainloop, logical_problem_shape,
shared_storage.mainloop,
shared_storage.mainloop_epilogue.mainloop,
pipeline_load_q, pipeline_load_q_consumer_state,
pipeline_load_kv, pipeline_load_kv_consumer_state,
pipeline_mma_s0, pipeline_mma_s0_producer_state,
@ -475,6 +547,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else if (role == WarpRole::Load) {
warpgroup_reg_set<NumRegsOther>();
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
@ -493,7 +570,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
mainloop.load(
blk_coord, logical_problem_shape,
params.mainloop, params.problem_shape,
shared_storage.mainloop,
shared_storage.mainloop_epilogue.mainloop,
pipeline_load_q, pipeline_load_q_producer_state,
pipeline_load_kv, pipeline_load_kv_producer_state
);
@ -517,7 +594,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
epilogue.store(
blk_coord, logical_problem_shape,
params.epilogue, params.problem_shape,
shared_storage.epilogue,
epilogue_storage,
pipeline_corr_epi, pipeline_corr_epi_consumer_state
);


@ -59,16 +59,34 @@ void __global__ fmha_reference_kernel(
extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mQ)));
auto id = make_identity_tensor(make_shape(1, 1));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) {
auto coord_L = idx2crd(idx_L, shape<3>(problem_shape_in));
auto coord_in = cute::make_tuple(idx_Q, _0{}, _0{}, coord_L);
auto get_coord_in = [&]() {
if constexpr (rank_v<decltype(get<2>(ProblemShapeIn{}))> == 2) {
return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), coord_L);
} else {
return cute::make_tuple(idx_Q, _0{}, _0{}, coord_L);
}
};
auto coord_in = get_coord_in();
auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<3,1>(coord_in));
int head_qk = 0;
int head_v = 0;
if constexpr (rank_v<decltype(get<2>(problem_shape))> == 2) {
// MLA case: head_qk = 192, head_v = 128
head_qk = size<2, 0>(problem_shape) + size<2, 1>(problem_shape);
head_v = size<2, 0>(problem_shape);
} else {
head_qk = size<2>(problem_shape);
head_v = head_qk;
}
if (get<0,0>(coord) >= get<0>(problem_shape)) continue;
int offset_Q = 0;
@ -82,7 +100,7 @@ void __global__ fmha_reference_kernel(
}
if (get<1>(problem_shape) == 0) {
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < head_qk; idx_D += blockDim.x) {
mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0);
}
@ -94,7 +112,7 @@ void __global__ fmha_reference_kernel(
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_D = 0; idx_D < size<2>(problem_shape); idx_D++) {
for (int idx_D = 0; idx_D < head_qk; idx_D++) {
ElementAccumulator eQ = mQ(idx_Q + offset_Q, idx_D, idx_L);
ElementAccumulator eK = mK(idx_K + offset_K, idx_D, idx_L);
acc += eQ * eK;
@ -128,7 +146,8 @@ void __global__ fmha_reference_kernel(
ElementAccumulator scale = 1.0f / sum;
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < head_v; idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
ElementAccumulator eV = mV(idx_K + offset_K, idx_D, idx_L);