Compare commits

..

1 Commits

Author SHA1 Message Date
58a5197b9d "Update CHANGELOG for 4.0 tagging" 2025-06-06 09:43:11 -04:00
17 changed files with 94 additions and 312 deletions

1
.github/CODEOWNERS vendored
View File

@ -1 +0,0 @@
# This file defines code ownership rules for the repository.

View File

@ -18,7 +18,7 @@
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``

View File

@ -59,7 +59,7 @@ To get started quickly - please refer :
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``

View File

@ -116,8 +116,6 @@ struct Options {
int h_k = 1;
int q = 256;
int k = 256;
std::vector<int> varlen_q;
std::vector<int> varlen_k;
int d = 128;
int warmup_iterations = 1;
int iterations = 3;
@ -183,76 +181,13 @@ struct Options {
cmd.get_cmd_line_argument("h_k", h_k, -1);
if (h_k == -1) h_k = h;
varlen = cmd.check_cmd_line_flag("varlen");
cmd.get_cmd_line_argument("q", q, -1);
cmd.get_cmd_line_argument("k", k, -1);
cmd.get_cmd_line_argument("b", b, -1);
std::string varlen_q_str;
cmd.get_cmd_line_argument("varlen-q", varlen_q_str);
std::string varlen_k_str;
cmd.get_cmd_line_argument("varlen-k", varlen_k_str);
if (varlen && ! varlen_q_str.empty()) {
varlen_q.clear();
while (! varlen_q_str.empty()) {
size_t pos = varlen_q_str.find(':');
varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_q_str = varlen_q_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_q.size());
}
if (b != static_cast<int>(varlen_q.size())) {
std::cout << "Error: Invalid --varlen-q length\n";
std::exit(-1);
}
int new_q = 0;
for (auto elem : varlen_q) {
new_q += elem;
}
if (q != -1) {
std::cout << "Error: Can't provide --q and --varlen-q\n";
std::exit(-1);
}
q = new_q;
}
if (varlen && ! varlen_k_str.empty()) {
varlen_k.clear();
while (! varlen_k_str.empty()) {
size_t pos = varlen_k_str.find(':');
varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_k_str = varlen_k_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_k.size());
}
if (b != static_cast<int>(varlen_k.size())) {
std::cout << " Error: Invalid --varlen-k length\n";
std::exit(-1);
}
int new_k = 0;
for (auto elem : varlen_k) {
new_k += elem;
}
if (k != -1) {
std::cout << "Error: Can't provide --k and --varlen-k\n";
std::exit(-1);
}
k = new_k;
}
if (q == -1) q = k;
if (k == -1) k = q;
if (q == -1 && k == -1) q = k = defaults.q;
cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
@ -262,6 +197,7 @@ struct Options {
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");
std::string mask;
@ -304,9 +240,7 @@ struct Options {
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
<< " --d=<int> Sets the D extent\n"
<< " --d=<int> Sets the D extentn"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
@ -541,10 +475,7 @@ struct FwdRunner {
}
template<class ProblemShape>
auto initialize_varlen(
const Options& options, const ProblemShape& problem_size,
const bool kVarlenSame = true) {
auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) {
int num_batches = get<3,1>(problem_size);
// generate Q as --b times
@ -572,12 +503,8 @@ struct FwdRunner {
int max_seqlen_kv = 0;
for (int i = 0; i < num_batches; i++) {
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
kVarlenSame ? get<0>(problem_size) :
generate_positive_int(dist_q, rng);
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
kVarlenSame ? get<1>(problem_size) :
generate_positive_int(dist_kv, rng);
int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng);
int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng);
total_seqlen_q += seqlen_q;
total_seqlen_kv += seqlen_kv;
@ -618,7 +545,7 @@ struct FwdRunner {
decltype(problem_shape_in) problem_size;
if constexpr (kIsVarlen) {
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(options, problem_shape_in);
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
problem_shape = problem_shape_launch;
problem_size = problem_shape_init;
}
@ -661,8 +588,6 @@ struct FwdRunner {
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_ref_LSE.reset(size(shape_LSE));
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
@ -888,18 +813,11 @@ struct FwdRunner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
if (! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
@ -1100,7 +1018,7 @@ int main_single(int argc, char const **args) {
});
#endif
return main_result;
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1108,6 +1026,8 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -1134,7 +1054,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return main_result;
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -689,18 +689,11 @@ struct ExampleRunner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
if (result.supported && ! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
@ -788,17 +781,12 @@ int main_single(int argc, char const **args) {
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
)
if (options.d == 128) {
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
}
else {
std::cout << "Head Dimension != 128 is not supported for the fmha_gen example\n";
}
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
#endif
return 0;
@ -809,6 +797,8 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -835,7 +825,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return main_result;
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -391,7 +391,11 @@ struct Runner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
#ifdef B2B
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
#else
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
#endif
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
@ -400,6 +404,7 @@ struct Runner {
}
bool passed_LSE = true;
#ifndef B2B
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
@ -407,6 +412,7 @@ struct Runner {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
#endif
return passed_O && passed_LSE;
}
@ -672,18 +678,11 @@ struct Runner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
if (! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
@ -807,6 +806,8 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -833,7 +834,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return main_result;
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -43,30 +43,14 @@ set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --v
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
set(TEST_VARLEN_00 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_01 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_02 --verify --varlen --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_03 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_04 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_05 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_06 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
set(TEST_VARLEN_07 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
set(TEST_VARLEN_08 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
set(TEST_VARLEN_09 --verify --varlen --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
set(TEST_VARLEN_10 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify)
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
@ -78,25 +62,10 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_VARLEN
TEST_HDIM64
TEST_GQA
TEST_VARLEN_00
TEST_VARLEN_01
TEST_VARLEN_02
TEST_VARLEN_03
TEST_VARLEN_04
TEST_VARLEN_05
TEST_VARLEN_06
TEST_VARLEN_07
TEST_VARLEN_08
TEST_VARLEN_09
TEST_VARLEN_10
TEST_VARLEN_11
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
# TEST_CAUSAL
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
@ -106,11 +75,11 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
TEST_GEN_VARLEN
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
TEST_GEN_GQA
TEST_GEN_REMAP
TEST_GEN_CACHEONLY
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
@ -135,12 +104,26 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_mla_b2b_2sm_${PREC}
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
)
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_${PREC}
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_VARLEN
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
@ -173,7 +156,11 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_mla_b2b_2sm_fp8
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_fmha_bwd_sat_fp8
77_blackwell_fmha_bwd_sat_fp16
)
endif()

View File

@ -55,7 +55,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
using ElementOut = Element;
struct TensorStorage {

View File

@ -942,15 +942,11 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
@ -1065,22 +1061,17 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
@ -1110,13 +1101,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
@ -1129,85 +1115,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
++pipeline_epi_producer_state;
}
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction_empty(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
float lse = -INFINITY;
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if (threadIdx.x % 128 == 0 && block0()) {
DSHOW(sO);
}
#if 1
using ElementOut = typename CollectiveEpilogue::ElementOut;
auto tiled_copy = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
sO.layout());
auto thr_copy = tiled_copy.get_slice(thread_idx);
auto tOgO = thr_copy.partition_D(sO);
auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
clear(tOrO);
copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
#endif
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
};
} // namespace cutlass::fmha::collective

View File

@ -831,7 +831,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// loop:
// TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
for (int i = 0; i < 128 / kCorrectionTileSize; i++) {
Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0;
tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1;
@ -917,7 +917,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<get<2>(TileShape{}) / kCorrectionTileSize>{}));
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;

View File

@ -170,8 +170,8 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
auto tSgQ = thr_mma_qk.partition_A(gQ);
auto tScQ = thr_mma_qk.partition_A(cQ);
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
auto tiled_copy_q = make_cotiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},

View File

@ -372,10 +372,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
bool is_softmax_0 = role == WarpRole::Softmax0;
mainloop.softmax(
@ -404,22 +400,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
mainloop.correction_empty(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
shared_storage.epilogue,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
continue;
}
mainloop.correction(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
shared_storage.epilogue,
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
@ -428,6 +411,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
epilogue
);
}
if constexpr (NumWarpsEpilogue == 0) {
@ -455,9 +439,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.mma(
blk_coord,
@ -470,6 +451,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_mma_corr, pipeline_mma_corr_producer_state
);
}
}
else if (role == WarpRole::Load) {
@ -486,10 +468,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.load(
blk_coord, logical_problem_shape,
params.mainloop, params.problem_shape,

View File

@ -784,6 +784,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
auto pages_per_tile = Pow2{TileShapeS{} / page_size};
int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp;
#if 1
for (; k_tile_count > 0; ++k_index, --k_tile_count) {
pipeline_page_table.producer_acquire(pipeline_pt_producer_state);
@ -804,6 +805,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_pt_producer_state;
}
#endif
}
@ -1637,6 +1639,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]);
}
#ifndef B2B
// find correction factor
ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast<ElementAcc>(M_LOG2E);
correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new));
@ -1648,6 +1651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (int i = 0; i < size(tTR_rAcc); i++) {
tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2);
}
#endif
// quantize
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
@ -1701,6 +1705,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
uint32_t tmem_o) {
// for b2b gemm, do nothing
#ifndef B2B
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
auto store_op = TMEM::tmem_load_to_store(load_op);
@ -1743,6 +1748,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
// store o
copy(tiled_r2t, tTR_rAcc, tTR_tAcc);
#endif
}
@ -1800,6 +1806,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
copy(tTR_rO_src, tR2G_rO_dst);
#ifndef B2B
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
@ -1811,6 +1819,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
{
gLSE(threadIdx.x) = lse;
}
#endif
}
else {
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
@ -1839,7 +1848,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
copy(tTR_rO_src, tR2G_rO_dst);
#ifndef B2B
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
@ -1854,6 +1863,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
gLSE(threadIdx.x) = lse;
}
}
#endif
}
}
@ -1970,6 +1980,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);
#ifdef B2B
row_sum = 1;
#else
if constexpr (kWarpsInN > 1) {
// reduce row_sum if needed (for 2x2 dp)
shared_tensors.smem_exchange[threadIdx.x] = row_sum;
@ -1978,6 +1991,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
int peer_index = (threadIdx.x + 64) % 128;
row_sum += shared_tensors.smem_exchange[peer_index];
}
#endif
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();

View File

@ -80,17 +80,6 @@ void __global__ fmha_reference_kernel(
if constexpr (rank<1>(decltype(coord){}) == 2) {
offset_K = get<1,1>(coord);
}
if (get<1>(problem_shape) == 0) {
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0);
}
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = -INFINITY;
}
continue;
}
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc = 0;

View File

@ -111,9 +111,11 @@ void __global__ fmha_mla_reference_kernel(
__syncthreads();
#ifndef B2B
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS));
}
#endif
__syncthreads();
@ -123,6 +125,9 @@ void __global__ fmha_mla_reference_kernel(
}
ElementAcc o_scale = 1.0f / sum;
#ifdef B2B
o_scale = 1.0;
#endif
for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) {
ElementAcc acc = 0;

View File

@ -101,12 +101,8 @@ __global__ void reference_abs_diff_kernel(
__shared__ double block_sum_diff;
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
if (data[i] == data_ref[i]) {
continue;
}
double diff = fabs(data[i] - data_ref[i]);
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
thread_max_diff = fmax(diff, thread_max_diff);
thread_sum_diff += diff;
}
@ -198,11 +194,8 @@ __global__ void reference_rel_diff_kernel(
__shared__ double block_sum_diff;
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
if (data[i] == data_ref[i]) {
continue;
}
double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]);
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
thread_max_diff = fmax(diff, thread_max_diff);
thread_sum_diff += diff;
}

View File

@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.0.0
nvidia-cutlass-dsl==4.0.0.dev1