v4.0 update. (#2398)

* Ex77 fix.
This commit is contained in:
Junkai-Wu
2025-06-12 21:10:29 +08:00
committed by GitHub
parent 5c6bca0441
commit dc4817921e
4 changed files with 142 additions and 27 deletions

View File

@ -116,6 +116,8 @@ struct Options {
int h_k = 1;
int q = 256;
int k = 256;
std::vector<int> varlen_q;
std::vector<int> varlen_k;
int d = 128;
int warmup_iterations = 1;
int iterations = 3;
@ -181,13 +183,76 @@ struct Options {
cmd.get_cmd_line_argument("h_k", h_k, -1);
if (h_k == -1) h_k = h;
varlen = cmd.check_cmd_line_flag("varlen");
cmd.get_cmd_line_argument("q", q, -1);
cmd.get_cmd_line_argument("k", k, -1);
cmd.get_cmd_line_argument("b", b, -1);
std::string varlen_q_str;
cmd.get_cmd_line_argument("varlen-q", varlen_q_str);
std::string varlen_k_str;
cmd.get_cmd_line_argument("varlen-k", varlen_k_str);
if (varlen && ! varlen_q_str.empty()) {
varlen_q.clear();
while (! varlen_q_str.empty()) {
size_t pos = varlen_q_str.find(':');
varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_q_str = varlen_q_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_q.size());
}
if (b != static_cast<int>(varlen_q.size())) {
std::cout << "Error: Invalid --varlen-q length\n";
std::exit(-1);
}
int new_q = 0;
for (auto elem : varlen_q) {
new_q += elem;
}
if (q != -1) {
std::cout << "Error: Can't provide --q and --varlen-q\n";
std::exit(-1);
}
q = new_q;
}
if (varlen && ! varlen_k_str.empty()) {
varlen_k.clear();
while (! varlen_k_str.empty()) {
size_t pos = varlen_k_str.find(':');
varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_k_str = varlen_k_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_k.size());
}
if (b != static_cast<int>(varlen_k.size())) {
std::cout << " Error: Invalid --varlen-k length\n";
std::exit(-1);
}
int new_k = 0;
for (auto elem : varlen_k) {
new_k += elem;
}
if (k != -1) {
std::cout << "Error: Can't provide --k and --varlen-k\n";
std::exit(-1);
}
k = new_k;
}
if (q == -1) q = k;
if (k == -1) k = q;
if (q == -1 && k == -1) q = k = defaults.q;
cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
@ -197,7 +262,6 @@ struct Options {
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");
std::string mask;
@ -240,7 +304,9 @@ struct Options {
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn"
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
<< " --d=<int> Sets the D extent\n"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
@ -475,7 +541,10 @@ struct FwdRunner {
}
template<class ProblemShape>
auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) {
auto initialize_varlen(
const Options& options, const ProblemShape& problem_size,
const bool kVarlenSame = true) {
int num_batches = get<3,1>(problem_size);
// generate Q as --b times
@ -503,8 +572,12 @@ struct FwdRunner {
int max_seqlen_kv = 0;
for (int i = 0; i < num_batches; i++) {
int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng);
int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng);
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
kVarlenSame ? get<0>(problem_size) :
generate_positive_int(dist_q, rng);
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
kVarlenSame ? get<1>(problem_size) :
generate_positive_int(dist_kv, rng);
total_seqlen_q += seqlen_q;
total_seqlen_kv += seqlen_kv;
@ -545,7 +618,7 @@ struct FwdRunner {
decltype(problem_shape_in) problem_size;
if constexpr (kIsVarlen) {
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(options, problem_shape_in);
problem_shape = problem_shape_launch;
problem_size = problem_shape_init;
}
@ -588,6 +661,8 @@ struct FwdRunner {
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_ref_LSE.reset(size(shape_LSE));
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);

View File

@ -43,6 +43,22 @@ set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --v
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
# Variable-length forward tests with explicit per-batch sequence lengths.
# --varlen is required: the example only parses --varlen-q / --varlen-k when that flag is
# set, otherwise the lists are ignored and the default fixed-length problem runs instead.
# The batch count is taken from the number of colon-separated entries.
# Cases 07, 08 and 13 contain zero-length sequences; 10, 11 and 14 use very short ones.
set(TEST_VARLEN_00 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_01 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_02 --verify --varlen --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_03 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_04 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_05 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_06 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
set(TEST_VARLEN_07 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
set(TEST_VARLEN_08 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
set(TEST_VARLEN_09 --verify --varlen --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
set(TEST_VARLEN_10 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
@ -62,10 +78,25 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
TEST_CAUSAL
TEST_VARLEN
TEST_HDIM64
TEST_GQA
TEST_VARLEN_00
TEST_VARLEN_01
TEST_VARLEN_02
TEST_VARLEN_03
TEST_VARLEN_04
TEST_VARLEN_05
TEST_VARLEN_06
TEST_VARLEN_07
TEST_VARLEN_08
TEST_VARLEN_09
TEST_VARLEN_10
TEST_VARLEN_11
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
)
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
@ -75,11 +106,11 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
TEST_GEN_VARLEN
TEST_GEN_HDIM64
TEST_GEN_GQA
TEST_GEN_REMAP
TEST_GEN_CACHEONLY
)
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
@ -119,11 +150,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
TEST_VARLEN
)
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
@ -160,7 +187,5 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_fmha_bwd_sat_fp8
77_blackwell_fmha_bwd_sat_fp16
)
endif()

View File

@ -942,11 +942,15 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
@ -1068,10 +1072,15 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
@ -1101,8 +1110,13 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}

View File

@ -403,6 +403,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
mainloop.correction(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
shared_storage.epilogue,
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,