Add Blackwell MLA forward (shape: d=192, dv=128) implementation in example_77 (#2472)
@ -853,7 +853,7 @@ struct FwdRunner {
flops *= static_cast<double>(size<1>(problem_shape));
flops *= static_cast<double>(size<3,1>(problem_shape));
}
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask<true>> || std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
flops *= static_cast<double>(size<2>(problem_shape));
flops *= static_cast<double>(size<3,0>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
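// [Hedged sketch, not code from this commit] The FLOP accounting above follows the usual
// attention model: two GEMMs (Q*K^T and P*V) at 2 flops per multiply-add give the factor 4.0,
// the (H, B) modes of the problem shape multiply in heads and batch, and causal masks halve
// the count. A standalone restatement with hypothetical names:
double attention_flops(double seq_q, double seq_k, double head_dim,
                       double heads, double batch, bool causal) {
  double flops = 4.0 * seq_q * seq_k * head_dim * heads * batch;
  return causal ? 0.5 * flops : flops;   // tflops_s = flops * 1e-12 / (runtime_ms * 1e-3)
}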
examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu (new file, 1069 lines)
File diff suppressed because it is too large
@ -33,6 +33,7 @@ set_property(
|
||||
77_blackwell_fmha_gen.cu
|
||||
77_blackwell_mla.cu
|
||||
77_blackwell_fmha_bwd.cu
|
||||
77_blackwell_mla_fwd.cu
|
||||
PROPERTY
|
||||
COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
|
||||
)
|
||||
@ -59,6 +60,22 @@ set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4
|
||||
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
|
||||
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
|
||||
set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_02 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_03 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_04 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_05 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_06 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_07 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
|
||||
set(TEST_MLA_FWD_VARLEN_08 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
|
||||
set(TEST_MLA_FWD_VARLEN_09 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
|
||||
set(TEST_MLA_FWD_VARLEN_10 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=2:3 --varlen-k=2:5)
|
||||
set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=11:10 --varlen-k=13:10)
|
||||
set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
|
||||
set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
|
||||
set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
|
||||
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
|
||||
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
|
||||
@ -161,6 +178,35 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_fwd_${PREC}
|
||||
77_blackwell_mla_fwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
TEST_CAUSAL
|
||||
TEST_VARLEN
|
||||
TEST_HDIM64
|
||||
TEST_GQA
|
||||
TEST_MLA_FWD_VARLEN_00
|
||||
TEST_MLA_FWD_VARLEN_01
|
||||
TEST_MLA_FWD_VARLEN_02
|
||||
TEST_MLA_FWD_VARLEN_03
|
||||
TEST_MLA_FWD_VARLEN_04
|
||||
TEST_MLA_FWD_VARLEN_05
|
||||
TEST_MLA_FWD_VARLEN_06
|
||||
TEST_MLA_FWD_VARLEN_07
|
||||
TEST_MLA_FWD_VARLEN_08
|
||||
TEST_MLA_FWD_VARLEN_09
|
||||
TEST_MLA_FWD_VARLEN_10
|
||||
TEST_MLA_FWD_VARLEN_11
|
||||
TEST_MLA_FWD_VARLEN_12
|
||||
TEST_MLA_FWD_VARLEN_13
|
||||
TEST_MLA_FWD_VARLEN_14
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_mla_fwd_${PREC} PRIVATE -Xptxas -v)
|
||||
endforeach()
|
||||
|
||||
# Add a target that builds all examples
|
||||
@ -176,5 +222,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla_2sm_cpasync_fp16
|
||||
77_blackwell_fmha_bwd_fp8
|
||||
77_blackwell_fmha_bwd_fp16
|
||||
77_blackwell_mla_fwd_fp8
|
||||
77_blackwell_mla_fwd_fp16
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -184,10 +184,16 @@ struct ResidualMaskForBackward : NoMask {
|
||||
}
|
||||
};
|
||||
|
||||
// There are two ways to do causal if N_Q != N_K
|
||||
// (1) The Q is at the beginning of the matrix
|
||||
// (2) The Q is at the end of the matrix
|
||||
template<bool kIsQBegin = true>
|
||||
struct CausalMask : NoMask {
|
||||
|
||||
using Base = NoMask;
|
||||
|
||||
static constexpr bool IsQBegin = kIsQBegin;
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
@ -197,9 +203,16 @@ struct CausalMask : NoMask {
|
||||
|
||||
// See note below on different ways to think about causal attention
|
||||
// Again, we'd add the offset_q into the max_blocks_q calculation
|
||||
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
if constexpr (IsQBegin) {
|
||||
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
} else {
|
||||
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
|
||||
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
}
|
||||
}
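// [Hedged illustration, not code from this commit] The tile-count math above, restated as a
// standalone helper with assumed sizes, to make the two alignments concrete:
//   n_q = 256, n_k = 1024, tile_q = tile_k = 128, last Q block (blk_q = 1):
//   kIsQBegin = true  -> min(8, ceil_div(2*128, 128))       = 2 K tiles
//   kIsQBegin = false -> min(8, ceil_div(2*128 + 768, 128)) = 8 K tiles (queries sit at the end)
int causal_trip_count(int blk_q, int tile_q, int tile_k, int n_q, int n_k, bool q_begin) {
  int max_blocks_k = (n_k + tile_k - 1) / tile_k;
  int offset_q = q_begin ? 0 : (n_k - n_q);
  int max_blocks_q = ((blk_q + 1) * tile_q + offset_q + tile_k - 1) / tile_k;
  return max_blocks_k < max_blocks_q ? max_blocks_k : max_blocks_q;
}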
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
@ -208,9 +221,14 @@ struct CausalMask : NoMask {
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
|
||||
if constexpr (IsQBegin) {
|
||||
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
|
||||
} else {
|
||||
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
|
||||
return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape));
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
@ -232,26 +250,36 @@ struct CausalMask : NoMask {
|
||||
|
||||
// There are two ways to do causal if N_Q != N_K
|
||||
// (1) is to assume that the Q is at the beginning of the matrix
|
||||
// - this is what we demonstrate here
|
||||
// - this is the default setting.
|
||||
// (2) is that it is at the end of the matrix
|
||||
// - this is usually what we want for inference settings
|
||||
// where we only compute the next row and use cache for the rest
|
||||
// - if you'd like this, you only need to add an offset like so:
|
||||
// get<0>(pos) + offset_q < get<1>(pos)
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
// - if you'd like this, you only need to set kIsQBegin=false
|
||||
|
||||
if constexpr (IsQBegin) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
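// [Hedged illustration, not code from this commit] The same predicate as a scalar check, with a
// tiny assumed case (n_q = 2, n_k = 4): kIsQBegin = true keeps (q,k) = (0,0) and (1,0..1);
// kIsQBegin = false uses offset_q = 2 and keeps (0,0..2) and (1,0..3), i.e. decode-style alignment.
bool causal_keep(int q, int k, int n_q, int n_k, bool q_begin) {
  int offset_q = q_begin ? 0 : (n_k - n_q);
  return (q + offset_q >= k) && (k < n_k);   // everything else is set to -INFINITY above
}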
|
||||
|
||||
struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {
|
||||
struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
|
||||
|
||||
using Base = CausalMask;
|
||||
using Base = CausalMask<true>;
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@ -42,7 +42,8 @@ template<
|
||||
class ElementAcc,
|
||||
class TileShape, // Q, D, _
|
||||
class StrideO, // Q, D, B
|
||||
class StrideLSE_ // Q, B
|
||||
class StrideLSE_, // Q, B
|
||||
class OrderLoadEpilogue = cute::false_type
|
||||
>
|
||||
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
@ -56,7 +57,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
using SmemLayoutO_ = SmemLayoutO;
|
||||
using StrideLSE = StrideLSE_;
|
||||
using ElementOut = Element;
|
||||
|
||||
|
||||
static const int NumWarpsEpilogue = 1;
|
||||
static const int NumWarpsLoad = 1;
|
||||
|
||||
struct TensorStorage {
|
||||
|
||||
using SmemLayoutO = SmemLayoutO_;
|
||||
@ -86,6 +90,19 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
StrideLSE dLSE;
|
||||
};
|
||||
|
||||
// FMHA and MLA have different input ProblemShapes;
|
||||
// get problem_shape_O according to the input ProblemShape.
|
||||
template<class ProblemShape>
|
||||
CUTLASS_DEVICE static constexpr
|
||||
auto get_problem_shape_O (
|
||||
ProblemShape const& problem_shape) {
|
||||
if constexpr (rank_v<decltype(get<2>(ProblemShape{}))> == 2) {
|
||||
return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));
|
||||
} else {
|
||||
return select<0,2,3>(problem_shape);
|
||||
}
|
||||
}
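// [Hedged illustration, not code from this commit] With the ProblemShape layouts of these examples
// assumed as (Q, K, D, (H, B)) for FMHA and (Q, K, (D_latent, D_rope), (H, B)) for MLA:
//   FMHA: get_problem_shape_O((Q, K, 128, (H, B)))       -> (Q, 128, (H, B))
//   MLA : get_problem_shape_O((Q, K, (128, 64), (H, B))) -> (Q, 128, (H, B))
// i.e. for MLA the O tensor is sized by the latent width dv = 128, not by d = 192.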
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
@ -94,7 +111,8 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
auto ptr_O = args.ptr_O;
|
||||
StrideO dO = args.dO;
|
||||
auto problem_shape_O = select<0,2,3>(problem_shape);
|
||||
|
||||
auto problem_shape_O = get_problem_shape_O(problem_shape);
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
@ -146,7 +164,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
int o0_index = 2 * get<0>(blk_coord);
|
||||
int o1_index = 2 * get<0>(blk_coord) + 1;
|
||||
|
||||
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
|
||||
// offset mode 0 by (max_length - real_length)
|
||||
// offset mode 3,1 by cumulative_length + real_length
|
||||
// the ptr is already offset by - max_length
|
||||
@ -201,6 +219,11 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
tma_store_wait<0>();
|
||||
|
||||
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
|
||||
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
}
|
||||
|
||||
pipeline.consumer_release(pipeline_release_state);
|
||||
++pipeline_release_state;
|
||||
|
||||
|
||||
@ -58,7 +58,9 @@ template<
|
||||
// and refers to the two softmax warps
|
||||
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
|
||||
// (1, 2, 1) means they sit side by side (best for small Q / large K)
|
||||
class ThreadShape = Shape<_2, _1, _1>
|
||||
class ThreadShape = Shape<_2, _1, _1>,
|
||||
// Since shared memory is sufficient for FMHA, there is no need to reuse shared memory.
|
||||
class OrderLoadEpilogue = cute::false_type
|
||||
>
|
||||
struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
@ -106,6 +108,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountKV>{}));
|
||||
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountKV>{}));
|
||||
|
||||
// Reuse shared memory for V and O.
|
||||
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
|
||||
struct TensorStorage {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
union {
|
||||
@ -168,9 +172,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
|
||||
|
||||
static const int TransactionBytesLoadKV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
|
||||
static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
|
||||
static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
|
||||
|
||||
static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>), "K and V smem layouts must be of equal size");
|
||||
static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size");
|
||||
|
||||
using Load = Sm100FmhaLoadTmaWarpspecialized<
|
||||
Element, StrideQ, StrideK, StrideV,
|
||||
@ -525,7 +530,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
|
||||
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
|
||||
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
|
||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
|
||||
File diff suppressed because it is too large
@ -0,0 +1,340 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
#include "collective/fmha_common.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class StrideQ,
|
||||
class StrideK,
|
||||
class StrideV,
|
||||
class CollectiveMmaQK,
|
||||
class CollectiveMmaPV,
|
||||
class SmemLayoutQ,
|
||||
class SmemLayoutK,
|
||||
class SmemLayoutV,
|
||||
class TensorStorage,
|
||||
class PipelineQ,
|
||||
class PipelineKV,
|
||||
class Mask,
|
||||
class TileShape,
|
||||
class OrderLoadEpilogue = cute::false_type
|
||||
>
|
||||
struct Sm100MlaFwdLoadTmaWarpspecialized {
|
||||
|
||||
using TileShapeQK = typename CollectiveMmaQK::TileShape;
|
||||
using TileShapePV = typename CollectiveMmaPV::TileShape;
|
||||
|
||||
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
|
||||
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
|
||||
|
||||
static const int NumWarpsEpilogue = 1;
|
||||
static const int NumWarpsLoad = 1;
|
||||
|
||||
struct Arguments {
|
||||
const Element* ptr_Q;
|
||||
StrideQ dQ;
|
||||
const Element* ptr_K;
|
||||
StrideK dK;
|
||||
const Element* ptr_V;
|
||||
StrideV dV;
|
||||
};
|
||||
|
||||
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||
|
||||
struct Params {
|
||||
TMA_Q tma_load_q;
|
||||
TMA_K tma_load_k;
|
||||
TMA_V tma_load_v;
|
||||
};
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace) {
|
||||
|
||||
auto ptr_Q = args.ptr_Q;
|
||||
auto ptr_K = args.ptr_K;
|
||||
auto ptr_V = args.ptr_V;
|
||||
auto dQ = args.dQ;
|
||||
auto dK = args.dK;
|
||||
auto dV = args.dV;
|
||||
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dQ) = get<0>(dQ);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_Q -= max_length_q * get<0>(dQ);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
|
||||
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_kv != nullptr) {
|
||||
int max_length_kv = get<1>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dK) = get<0>(dK);
|
||||
get<2,1>(dV) = get<0>(dV);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_K -= max_length_kv * get<0>(dK);
|
||||
ptr_V -= max_length_kv * get<0>(dV);
|
||||
}
|
||||
}
|
||||
|
||||
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
|
||||
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
|
||||
problem_shape_qk,
|
||||
typename CollectiveMmaQK::Arguments {
|
||||
ptr_Q, dQ,
|
||||
ptr_K, dK,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
|
||||
problem_shape_pv,
|
||||
typename CollectiveMmaPV::Arguments {
|
||||
ptr_K, dK, // never used, dummy
|
||||
ptr_V, select<1,0,2>(dV),
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
return Params{
|
||||
params_qk.tma_load_a,
|
||||
params_qk.tma_load_b,
|
||||
params_pv.tma_load_b
|
||||
};
|
||||
}
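// [Hedged illustration, not code from this commit] For the d = 192, dv = 128 configuration of this
// example, the shape plumbing above works out to (ProblemShape layout assumed as in the FMHA examples):
//   problem_shape    = (Q, K, (128, 64), (H, B))
//   problem_shape_qk = (Q, K, 192, (H, B))   // Q*K^T contracts over latent + rope = 128 + 64
//   problem_shape_pv = (Q, 128, K, (H, B))   // P*V contracts over K and emits dv = D_latent = 128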
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& storage,
|
||||
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
|
||||
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
|
||||
|
||||
BlkCoord blk_coord_q = blk_coord_in;
|
||||
BlkCoord blk_coord_kv = blk_coord_in;
|
||||
|
||||
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
|
||||
auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));
|
||||
|
||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
// this one is only executed by one thread, no need to elect_one
|
||||
|
||||
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
|
||||
// two pipes: Q and KV
|
||||
// from Memory (prod) to TensorCore (cons)
|
||||
|
||||
// compute gQ, sQ
|
||||
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
|
||||
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
|
||||
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
|
||||
|
||||
int q_offs_0 = 0;
|
||||
int q_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
q_offs_0 = max_length_q - get<0>(problem_shape);
|
||||
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
|
||||
get<2,1>(blk_coord_q) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
|
||||
|
||||
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
auto [tQgQ_qdl, tQsQ] = tma_partition(
|
||||
params.tma_load_q, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
|
||||
);
|
||||
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
|
||||
|
||||
// compute gK, sK
|
||||
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
|
||||
|
||||
int kv_offs_0 = 0;
|
||||
int kv_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
|
||||
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length != nullptr) {
|
||||
int max_length = get<1>(params_problem_shape).max_length;
|
||||
kv_offs_0 = max_length - get<1>(problem_shape);
|
||||
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
|
||||
get<2,1>(blk_coord_kv) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
|
||||
|
||||
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
auto [tKgK_kdl, tKsK] = tma_partition(
|
||||
params.tma_load_k, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
|
||||
);
|
||||
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
|
||||
|
||||
// compute gV, sV
|
||||
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
|
||||
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
|
||||
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
|
||||
|
||||
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
|
||||
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
auto [tVgV_dkl, tVsV] = tma_partition(
|
||||
params.tma_load_v, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
|
||||
);
|
||||
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
|
||||
|
||||
// blk_coord is decomposed in terms of TileShape, not TileShapeQK
|
||||
// As such, it needs to be transformed as
|
||||
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
|
||||
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
|
||||
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Q1
|
||||
int q0_index = 2 * get<0>(blk_coord_q);
|
||||
int q1_index = 2 * get<0>(blk_coord_q) + 1;
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
// K1
|
||||
int k_index = 0;
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Q2
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
|
||||
cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
}
|
||||
|
||||
// V1
|
||||
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
// loop:
|
||||
mask_tile_count -= 1;
|
||||
for (; mask_tile_count > 0; mask_tile_count -= 1) {
|
||||
|
||||
// Ki
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
|
||||
|
||||
// prefetch vi
|
||||
cute::prefetch(params.tma_load_v, tVgV(_, k_index));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Vi
|
||||
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
|
||||
|
||||
// prefetch ki+1
|
||||
if(mask_tile_count > 1) {
|
||||
cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));
|
||||
}
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
examples/77_blackwell_fmha/common/pipeline_mla.hpp (new file, 250 lines)
@ -0,0 +1,250 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Allow the producer to acquire a specific number of bytes of data.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/pipeline/sm100_pipeline.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <
|
||||
int Stages_,
|
||||
class ClusterShape = Shape<int,int,_1>,
|
||||
class AtomThrShape_MNK_ = Shape<_1,_1,_1>
|
||||
>
|
||||
class PipelineTmaAsyncMla {
|
||||
|
||||
public:
|
||||
static constexpr uint32_t Stages = Stages_;
|
||||
using AtomThrShape_MNK = AtomThrShape_MNK_;
|
||||
|
||||
private:
|
||||
using Impl = PipelineTmaUmmaAsync<Stages_, ClusterShape, AtomThrShape_MNK_>;
|
||||
|
||||
public:
|
||||
using FullBarrier = typename Impl::FullBarrier;
|
||||
using EmptyBarrier = typename Impl::EmptyBarrier;
|
||||
using ProducerBarrierType = typename Impl::ProducerBarrierType;
|
||||
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
|
||||
using PipelineState = typename Impl::PipelineState;
|
||||
using SharedStorage = typename Impl::SharedStorage;
|
||||
using ThreadCategory = typename Impl::ThreadCategory;
|
||||
using Params = typename Impl::Params;
|
||||
|
||||
|
||||
using McastDirection = McastDirection;
|
||||
|
||||
// Helper function to initialize barriers
|
||||
static
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
if (warp_idx == params.initializing_warp) {
|
||||
// Barrier FULL and EMPTY init
|
||||
constexpr int producer_arv_cnt = 1;
|
||||
auto atom_thr_shape = AtomThrShape_MNK{};
|
||||
uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
|
||||
(cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1;
|
||||
|
||||
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
|
||||
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
static
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {
|
||||
auto atom_thr_shape = AtomThrShape_MNK{};
|
||||
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
if (warp_idx == params.initializing_warp) {
|
||||
// Barrier FULL and EMPTY init
|
||||
constexpr int producer_arv_cnt = 1;
|
||||
uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ?
|
||||
cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas
|
||||
cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas
|
||||
|
||||
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
|
||||
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
|
||||
// Calculate consumer mask
|
||||
if (params_.role == ThreadCategory::Consumer) {
|
||||
auto cluster_layout = make_layout(cluster_shape);
|
||||
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
|
||||
// Calculate consumer mask
|
||||
dim3 block_id_in_cluster = cute::block_id_in_cluster();
|
||||
auto cluster_layout = make_layout(cluster_shape);
|
||||
if (mcast_direction == McastDirection::kRow) {
|
||||
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
|
||||
}
|
||||
else {
|
||||
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
|
||||
CUTLASS_DEVICE
|
||||
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})
|
||||
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
|
||||
, params_(params)
|
||||
, empty_barrier_ptr_(&storage.empty_barrier_[0])
|
||||
, full_barrier_ptr_(&storage.full_barrier_[0]) {
|
||||
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
|
||||
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
|
||||
init_barriers(storage, params_, cluster_shape);
|
||||
}
|
||||
|
||||
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
|
||||
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
|
||||
init_masks(cluster_shape);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
|
||||
CUTLASS_DEVICE
|
||||
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})
|
||||
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
|
||||
, params_(params)
|
||||
, empty_barrier_ptr_(&storage.empty_barrier_[0])
|
||||
, full_barrier_ptr_(&storage.full_barrier_[0]) {
|
||||
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
|
||||
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
|
||||
init_barriers(storage, params_, cluster_shape, mcast_direction);
|
||||
}
|
||||
|
||||
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
|
||||
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
|
||||
init_masks(cluster_shape, mcast_direction);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
|
||||
impl_.producer_acquire(state, barrier_token);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {
|
||||
detail::pipeline_check_is_producer(params_.role);
|
||||
if (barrier_token != BarrierStatus::WaitDone) {
|
||||
empty_barrier_ptr_[stage].wait(phase);
|
||||
}
|
||||
|
||||
if (params_.is_leader) {
|
||||
full_barrier_ptr_[stage].arrive_and_expect_tx(bytes);
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) {
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
}
|
||||
|
||||
// Most likely you have elected more than one leader
|
||||
if (params_.is_leader && (threadIdx.x % 32 != 0)) {
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
|
||||
producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ProducerBarrierType* producer_get_barrier(PipelineState state) {
|
||||
return impl_.producer_get_barrier(state);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
|
||||
impl_.consumer_wait(state, barrier_token);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void consumer_release(PipelineState state) {
|
||||
consumer_release(state.index(), false);
|
||||
}
|
||||
|
||||
private:
|
||||
Impl impl_;
|
||||
Params params_;
|
||||
EmptyBarrier *empty_barrier_ptr_;
|
||||
FullBarrier *full_barrier_ptr_;
|
||||
uint16_t block_id_mask_ = 0;
|
||||
static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
|
||||
|
||||
// Consumer signalling Producer of completion
|
||||
// Ensures all blocks in the same row and column get notified.
|
||||
CUTLASS_DEVICE
|
||||
void consumer_release(uint32_t stage, uint32_t skip) {
|
||||
detail::pipeline_check_is_consumer(params_.role);
|
||||
uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
|
||||
if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
|
||||
if (!skip) {
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (!skip) {
|
||||
if constexpr (cute::is_static_v<ClusterShape> and size(ClusterShape{}) == 1) {
|
||||
cutlass::arch::umma_arrive(smem_ptr);
|
||||
}
|
||||
else {
|
||||
cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
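// [Hedged usage sketch, not code from this commit; names abbreviated from the MLA load collective
// above] The point of producer_acquire_bytes() is that the MLA forward load can post a per-stage
// byte count: K tiles carry d = 192 columns while V tiles carry dv = 128, so the V acquire
// overrides the pipeline-wide default transaction bytes, e.g.:
//   pipeline_kv.producer_acquire(kv_state);                               // K tile, default bytes
//   copy(params.tma_load_k.with(*pipeline_kv.producer_get_barrier(kv_state), 0), tKgK(_, k_index), tKsK(_, kv_state.index() / 2));
//   ++kv_state;
//   pipeline_kv.producer_acquire_bytes(kv_state, TransactionBytesLoadV);  // V tile, fewer bytes
//   copy(params.tma_load_v.with(*pipeline_kv.producer_get_barrier(kv_state), 0), tVgV(_, k_index), tVsV(_, kv_state.index() / 2));
//   ++kv_state;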
|
||||
examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp (new file, 197 lines)
@ -0,0 +1,197 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Swizzle Q tile and H tile to improve L2 cache hit rate,
|
||||
// and launch the longest main loop first to keep most SMs busy.
|
||||
|
||||
struct CausalIndividualTileScheduler {
|
||||
|
||||
static constexpr int TileQ = 16;
|
||||
static constexpr int TileH = 8;
|
||||
static constexpr int TileSize = TileQ * TileH;
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
int tile_max_q;
|
||||
FastDivmod divmod_tile_col;
|
||||
FastDivmod divmod_tile_size;
|
||||
FastDivmod divmod_tile_head;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CausalIndividualTileScheduler(Params const& params) : params(params) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));
|
||||
// gridDim.x must be a multiple of TileH
|
||||
const int tile_col_count = grid.x / TileH;
|
||||
const int tile_max_q = grid.y / TileQ * TileQ;
|
||||
return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;
|
||||
|
||||
int tile_idx, tile_tail;
|
||||
params.divmod_tile_size(tile_idx, tile_tail, block_idx);
|
||||
|
||||
int tile_row_idx, tile_col_idx;
|
||||
params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx);
|
||||
|
||||
int row_offset_in_tail, col_offset_in_tail;
|
||||
params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail);
|
||||
|
||||
const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;
|
||||
const int col_idx = tile_col_idx * TileH + col_offset_in_tail;
|
||||
|
||||
// last q tile launch first
|
||||
if(blockIdx.y >= params.tile_max_q) {
|
||||
return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));
|
||||
}
|
||||
|
||||
return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CausalIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
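// [Hedged worked example, not code from this commit] With assumed gridDim.x = 16 heads and
// gridDim.y = 64 Q tiles: blocks are renumbered inside 16 x 8 (TileQ x TileH) super-tiles of
// 128 blocks, so 8 consecutive blocks share one Q tile index across 8 neighbouring heads
// (which the scheduler comment above credits with better L2 hit rates); the Q coordinate is
// mirrored (gridDim.y - 1 - row_idx) so the longest causal rows launch first, and tail rows
// past tile_max_q keep the plain mirrored mapping.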
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Launch order: H Q B
|
||||
struct CausalPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_h;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
|
||||
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) },
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, bidh;
|
||||
params.divmod_h(block_decode, bidh, block_decode);
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CausalPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
@ -1245,7 +1245,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
};
|
||||
|
||||
bool leading_causal_masking = false;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|
||||
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
|
||||
}
|
||||
bool trailing_residual_masking = false;
|
||||
@ -1682,7 +1683,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
);
|
||||
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
|
||||
int iter_start = 0;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask>) {
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
|
||||
std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
|
||||
}
|
||||
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
@ -38,6 +39,7 @@
|
||||
|
||||
#include "kernel/fmha_options.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
#include "kernel/fmha_causal_tile_scheduler.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "collective/fmha_common.hpp"
|
||||
|
||||
@ -79,6 +81,45 @@ struct Sm100FmhaCtxKernelWarpspecializedSchedule {
|
||||
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = 16;
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct Sm100MlaFwdCtxKernelWarpspecializedSchedule {
|
||||
|
||||
enum class WarpRole {
|
||||
Softmax0,
|
||||
Softmax1,
|
||||
Correction,
|
||||
MMA,
|
||||
Load,
|
||||
Epilogue,
|
||||
Empty
|
||||
};
|
||||
|
||||
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
int wg_idx = warp_idx / 4; // warp_idx
|
||||
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
|
||||
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
|
||||
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
|
||||
if (warp_idx == 12) return WarpRole::MMA; // 12
|
||||
if (warp_idx == 13) return WarpRole::Load; // 13
|
||||
if (warp_idx == 14) return WarpRole::Epilogue; // 14
|
||||
return WarpRole::Empty; // 15
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = 4;
|
||||
static const int NumWarpsCorrection = 4;
|
||||
static const int NumWarpsEpilogue = 1;
|
||||
static const int NumWarpsLoad = 1;
|
||||
|
||||
static const bool kDebugUsingPrintf = false;
|
||||
static const int NumRegsSoftmax = 184;
|
||||
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = 16;
|
||||
|
||||
@ -106,6 +147,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
|
||||
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
|
||||
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
|
||||
|
||||
static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue);
|
||||
static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad);
|
||||
|
||||
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
|
||||
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
|
||||
@ -114,13 +158,31 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
|
||||
static const int NumWarps = KernelSchedule::NumWarps;
|
||||
|
||||
static constexpr bool IsMla = std::is_same_v<KernelSchedule, Sm100MlaFwdCtxKernelWarpspecializedSchedule>;
|
||||
|
||||
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||
|
||||
using TmemAllocator = cute::TMEM::Allocator1Sm;
|
||||
|
||||
struct SharedStorage {
|
||||
typename CollectiveMainloop::TensorStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
using UnionType = union {
|
||||
typename CollectiveMainloop::TensorStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
};
|
||||
|
||||
using StructType = struct {
|
||||
typename CollectiveMainloop::TensorStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
};
|
||||
|
||||
static constexpr bool IsPersistent = std::is_same_v<TileScheduler, PersistentTileScheduler> || std::is_same_v<TileScheduler, CausalPersistentTileScheduler>;
|
||||
using MainloopEpilogueStorage = std::conditional_t<IsPersistent,
|
||||
std::conditional_t<IsMla,
|
||||
std::conditional_t<CollectiveMainloop::IsOrderLoadEpilogue, UnionType, StructType>,
|
||||
StructType>,
|
||||
UnionType>;
|
||||
|
||||
MainloopEpilogueStorage mainloop_epilogue;
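// [Hedged summary, not code from this commit] The conditional above selects:
//   non-persistent scheduler                      -> union  (epilogue smem may alias mainloop smem)
//   persistent, FMHA                              -> struct (both regions live across tiles)
//   persistent, MLA, IsOrderLoadEpilogue == false -> struct
//   persistent, MLA, IsOrderLoadEpilogue == true  -> union  (load/epilogue order is enforced by the named barrier)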
|
||||
|
||||
struct PipelineStorage {
|
||||
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
|
||||
@ -206,6 +268,16 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
auto get_epilogue_storage = [&]() {
|
||||
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
|
||||
return reinterpret_cast<typename CollectiveEpilogue::TensorStorage *>(shared_storage.mainloop_epilogue.mainloop.smem_o.data());
|
||||
} else {
|
||||
return &shared_storage.mainloop_epilogue.epilogue;
|
||||
}
|
||||
};
|
||||
typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage();
|
||||
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
|
||||
@ -228,7 +300,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);
|
||||
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadKV;
|
||||
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK;
|
||||
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
|
||||
shared_storage.pipelines.load_kv,
|
||||
pipeline_load_kv_params,
|
||||
@ -409,7 +481,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
epilogue_storage,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state,
|
||||
epilogue
|
||||
);
|
||||
@ -420,7 +492,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
epilogue_storage,
|
||||
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
||||
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
||||
@ -462,7 +534,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
mainloop.mma(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.mainloop,
|
||||
shared_storage.mainloop_epilogue.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_consumer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_consumer_state,
|
||||
pipeline_mma_s0, pipeline_mma_s0_producer_state,
|
||||
@ -475,6 +547,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
else if (role == WarpRole::Load) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
|
||||
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
@ -493,7 +570,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
mainloop.load(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.mainloop, params.problem_shape,
|
||||
shared_storage.mainloop,
|
||||
shared_storage.mainloop_epilogue.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_producer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_producer_state
|
||||
);
|
||||
@ -517,7 +594,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
epilogue.store(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.epilogue, params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
epilogue_storage,
|
||||
pipeline_corr_epi, pipeline_corr_epi_consumer_state
|
||||
);
|
||||
|
||||
|
||||
@ -59,16 +59,34 @@ void __global__ fmha_reference_kernel(
|
||||
extern __shared__ char mS_mem[];
|
||||
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
|
||||
|
||||
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mQ)));
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
||||
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) {
|
||||
|
||||
|
||||
auto coord_L = idx2crd(idx_L, shape<3>(problem_shape_in));
|
||||
auto coord_in = cute::make_tuple(idx_Q, _0{}, _0{}, coord_L);
|
||||
auto get_coord_in = [&]() {
|
||||
if constexpr (rank_v<decltype(get<2>(ProblemShapeIn{}))> == 2) {
|
||||
return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), coord_L);
|
||||
} else {
|
||||
return cute::make_tuple(idx_Q, _0{}, _0{}, coord_L);
|
||||
}
|
||||
};
|
||||
auto coord_in = get_coord_in();
|
||||
auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<3,1>(coord_in));
|
||||
|
||||
int head_qk = 0;
|
||||
int head_v = 0;
|
||||
if constexpr (rank_v<decltype(get<2>(problem_shape))> == 2) {
|
||||
// MLA case: head_qk = 192, head_v = 128
|
||||
head_qk = size<2, 0>(problem_shape) + size<2, 1>(problem_shape);
|
||||
head_v = size<2, 0>(problem_shape);
|
||||
} else {
|
||||
head_qk = size<2>(problem_shape);
|
||||
head_v = head_qk;
|
||||
}
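// [Hedged restatement, not code from this commit] For this example's MLA shapes the two branches
// give head_qk = 128 + 64 = 192 (the reduction width of Q*K^T, which also sets
// softmax_scale = 1/sqrt(192) now that the scale is taken from mQ rather than mO) and
// head_v = 128 (the width of the softmax(S)*V output).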
|
||||
|
||||
if (get<0,0>(coord) >= get<0>(problem_shape)) continue;
|
||||
|
||||
int offset_Q = 0;
|
||||
@ -82,7 +100,7 @@ void __global__ fmha_reference_kernel(
|
||||
}
|
||||
|
||||
if (get<1>(problem_shape) == 0) {
|
||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
||||
for (int idx_D = threadIdx.x; idx_D < head_qk; idx_D += blockDim.x) {
|
||||
mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0);
|
||||
}
|
||||
|
||||
@ -94,7 +112,7 @@ void __global__ fmha_reference_kernel(
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_D = 0; idx_D < size<2>(problem_shape); idx_D++) {
|
||||
for (int idx_D = 0; idx_D < head_qk; idx_D++) {
|
||||
ElementAccumulator eQ = mQ(idx_Q + offset_Q, idx_D, idx_L);
|
||||
ElementAccumulator eK = mK(idx_K + offset_K, idx_D, idx_L);
|
||||
acc += eQ * eK;
|
||||
@ -128,7 +146,8 @@ void __global__ fmha_reference_kernel(
|
||||
|
||||
ElementAccumulator scale = 1.0f / sum;
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < head_v; idx_D += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
|
||||
ElementAccumulator eV = mV(idx_K + offset_K, idx_D, idx_L);
|
||||