Compare commits
1 Commits
v4.0.0
...
feature/en
| Author | SHA1 | Date | |
|---|---|---|---|
| 6aa1894093 |
@ -35,13 +35,7 @@
|
||||
- Added non-power-of-two tile sizes.
|
||||
- Improved performance for K-major scale factors.
|
||||
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
|
||||
* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
- Support LSE output in FMHA Forward kernel.
|
||||
- Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent.
|
||||
- Enhance testing of variable sequence length.
|
||||
- Disable B2B mode in MLA to simplify the sample.
|
||||
- Clarify that `fmha_gen` sample only supports head dim 128.
|
||||
- Fixes for split-kv output in MLA.
|
||||
* Support LSE output in Blackwell SM100 FMHA Forward kernel in example 77.
|
||||
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
|
||||
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
|
||||
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
|
||||
|
||||
@ -76,13 +76,7 @@ To get started quickly - please refer :
|
||||
- Added non-power-of-two tile sizes.
|
||||
- Improved performance for K-major scale factors.
|
||||
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
|
||||
* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
- Support LSE output in FMHA Forward kernel.
|
||||
- Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent.
|
||||
- Enhance testing of variable sequence length.
|
||||
- Disable B2B mode in MLA to simplify the sample.
|
||||
- Clarify that `fmha_gen` sample only supports head dim 128.
|
||||
- Fixes for split-kv output in MLA.
|
||||
* Support LSE output in Blackwell SM100 FMHA Forward kernel in example 77.
|
||||
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
|
||||
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
|
||||
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
|
||||
|
||||
@ -888,18 +888,11 @@ struct FwdRunner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
if (! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
@ -1100,7 +1093,7 @@ int main_single(int argc, char const **args) {
|
||||
});
|
||||
#endif
|
||||
|
||||
return main_result;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1108,6 +1101,8 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -1134,7 +1129,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return main_result;
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -689,18 +689,11 @@ struct ExampleRunner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
|
||||
if (result.supported && ! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
|
||||
@ -788,17 +781,12 @@ int main_single(int argc, char const **args) {
|
||||
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
|
||||
)
|
||||
|
||||
if (options.d == 128) {
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
}
|
||||
else {
|
||||
std::cout << "Head Dimension != 128 is not supported for the fmha_gen example\n";
|
||||
}
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
@ -809,6 +797,8 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -835,7 +825,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return main_result;
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -391,7 +391,11 @@ struct Runner {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
#ifdef B2B
|
||||
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#else
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#endif
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
@ -400,6 +404,7 @@ struct Runner {
|
||||
}
|
||||
|
||||
bool passed_LSE = true;
|
||||
#ifndef B2B
|
||||
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
@ -407,6 +412,7 @@ struct Runner {
|
||||
std::cerr << "failed LSE: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
@ -672,18 +678,11 @@ struct Runner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
if (! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
|
||||
@ -807,6 +806,8 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -833,7 +834,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return main_result;
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -43,30 +43,30 @@ set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --v
|
||||
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
|
||||
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
|
||||
|
||||
set(TEST_VARLEN_00 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_01 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_02 --verify --varlen --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_03 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_04 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_05 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_06 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
|
||||
set(TEST_VARLEN_07 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
|
||||
set(TEST_VARLEN_08 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
|
||||
set(TEST_VARLEN_09 --verify --varlen --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
|
||||
set(TEST_VARLEN_10 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
|
||||
set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
|
||||
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
|
||||
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
|
||||
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
set(TEST_VARLEN_00 --verify --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_01 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_02 --verify --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_03 --verify --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_04 --verify --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_05 --verify --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_06 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
|
||||
set(TEST_VARLEN_07 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
|
||||
set(TEST_VARLEN_08 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
|
||||
set(TEST_VARLEN_09 --verify --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
|
||||
set(TEST_VARLEN_10 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
|
||||
set(TEST_VARLEN_11 --verify --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
|
||||
set(TEST_VARLEN_12 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
|
||||
set(TEST_VARLEN_13 --verify --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
|
||||
set(TEST_VARLEN_14 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
|
||||
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
|
||||
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
|
||||
@ -107,7 +107,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
TEST_GEN_HDIM64
|
||||
TEST_GEN_GQA
|
||||
TEST_GEN_REMAP
|
||||
TEST_GEN_CACHEONLY
|
||||
@ -135,6 +135,16 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
|
||||
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_b2b_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
|
||||
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
@ -173,6 +183,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla_2sm_fp16
|
||||
77_blackwell_mla_2sm_cpasync_fp8
|
||||
77_blackwell_mla_2sm_cpasync_fp16
|
||||
77_blackwell_mla_b2b_2sm_fp8
|
||||
77_blackwell_mla_b2b_2sm_fp16
|
||||
77_blackwell_fmha_bwd_fp8
|
||||
77_blackwell_fmha_bwd_fp16
|
||||
)
|
||||
|
||||
@ -55,7 +55,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
|
||||
using SmemLayoutO_ = SmemLayoutO;
|
||||
using StrideLSE = StrideLSE_;
|
||||
using ElementOut = Element;
|
||||
|
||||
struct TensorStorage {
|
||||
|
||||
|
||||
@ -1065,7 +1065,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
// F2FP
|
||||
// store to smem
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
|
||||
|
||||
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
|
||||
|
||||
@ -1129,85 +1129,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
++pipeline_epi_producer_state;
|
||||
}
|
||||
|
||||
|
||||
template<
|
||||
class BlkCoord, class ProblemShape, class ParamsProblemShape,
|
||||
class TensorStorageEpi, class CollectiveEpilogue
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
correction_empty(
|
||||
BlkCoord const& blk_coord,
|
||||
Params const& params, ProblemShape const& problem_shape,
|
||||
ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorageEpi& shared_storage_epi,
|
||||
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
|
||||
CollectiveEpilogue& epilogue) {
|
||||
|
||||
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
|
||||
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
|
||||
float lse = -INFINITY;
|
||||
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
|
||||
|
||||
#define DSHOW(x) print(#x ": "); print(x); print("\n")
|
||||
if (threadIdx.x % 128 == 0 && block0()) {
|
||||
DSHOW(sO);
|
||||
}
|
||||
#if 1
|
||||
|
||||
using ElementOut = typename CollectiveEpilogue::ElementOut;
|
||||
auto tiled_copy = make_cotiled_copy(
|
||||
Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
|
||||
make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
|
||||
sO.layout());
|
||||
|
||||
auto thr_copy = tiled_copy.get_slice(thread_idx);
|
||||
auto tOgO = thr_copy.partition_D(sO);
|
||||
auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
|
||||
clear(tOrO);
|
||||
|
||||
copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
|
||||
#endif
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||
++pipeline_epi_producer_state;
|
||||
|
||||
copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||
++pipeline_epi_producer_state;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
|
||||
@ -831,7 +831,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
// loop:
|
||||
// TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
|
||||
for (int i = 0; i < 128 / kCorrectionTileSize; i++) {
|
||||
Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0;
|
||||
tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||
Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1;
|
||||
@ -917,7 +917,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
float2 scale_f32x2 = make_float2(scale, scale);
|
||||
|
||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<get<2>(TileShape{}) / kCorrectionTileSize>{}));
|
||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
||||
|
||||
auto copy_in = [&](int i) {
|
||||
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
||||
|
||||
@ -170,8 +170,8 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
auto tSgQ = thr_mma_qk.partition_A(gQ);
|
||||
auto tScQ = thr_mma_qk.partition_A(cQ);
|
||||
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
|
||||
auto tiled_copy_q = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},
|
||||
|
||||
@ -372,10 +372,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_softmax_0 = role == WarpRole::Softmax0;
|
||||
|
||||
mainloop.softmax(
|
||||
@ -404,18 +400,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
mainloop.correction_empty(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state,
|
||||
epilogue
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.correction(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
@ -428,6 +412,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
epilogue
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
if constexpr (NumWarpsEpilogue == 0) {
|
||||
@ -455,9 +440,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.mma(
|
||||
blk_coord,
|
||||
@ -470,6 +452,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
pipeline_mma_corr, pipeline_mma_corr_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Load) {
|
||||
@ -486,10 +469,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.load(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.mainloop, params.problem_shape,
|
||||
|
||||
@ -784,6 +784,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
auto pages_per_tile = Pow2{TileShapeS{} / page_size};
|
||||
int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp;
|
||||
|
||||
#if 1
|
||||
for (; k_tile_count > 0; ++k_index, --k_tile_count) {
|
||||
pipeline_page_table.producer_acquire(pipeline_pt_producer_state);
|
||||
|
||||
@ -804,6 +805,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
++pipeline_pt_producer_state;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -1637,6 +1639,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]);
|
||||
}
|
||||
|
||||
#ifndef B2B
|
||||
// find correction factor
|
||||
ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast<ElementAcc>(M_LOG2E);
|
||||
correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new));
|
||||
@ -1648,6 +1651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (int i = 0; i < size(tTR_rAcc); i++) {
|
||||
tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2);
|
||||
}
|
||||
#endif
|
||||
|
||||
// quantize
|
||||
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
|
||||
@ -1701,6 +1705,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
uint32_t tmem_o) {
|
||||
|
||||
// for b2b gemm, do nothing
|
||||
#ifndef B2B
|
||||
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
|
||||
auto store_op = TMEM::tmem_load_to_store(load_op);
|
||||
|
||||
@ -1743,6 +1748,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
// store o
|
||||
copy(tiled_r2t, tTR_rAcc, tTR_tAcc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -1800,6 +1806,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
copy(tTR_rO_src, tR2G_rO_dst);
|
||||
|
||||
#ifndef B2B
|
||||
|
||||
// compute LSE
|
||||
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
|
||||
|
||||
@ -1811,6 +1819,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
{
|
||||
gLSE(threadIdx.x) = lse;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else {
|
||||
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
|
||||
@ -1839,7 +1848,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
copy(tTR_rO_src, tR2G_rO_dst);
|
||||
|
||||
|
||||
#ifndef B2B
|
||||
if (epilogue_args.ptr_lse != nullptr) {
|
||||
// compute LSE
|
||||
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
|
||||
@ -1854,6 +1863,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
gLSE(threadIdx.x) = lse;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -1970,6 +1980,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);
|
||||
|
||||
#ifdef B2B
|
||||
row_sum = 1;
|
||||
#else
|
||||
if constexpr (kWarpsInN > 1) {
|
||||
// reduce row_sum if needed (for 2x2 dp)
|
||||
shared_tensors.smem_exchange[threadIdx.x] = row_sum;
|
||||
@ -1978,6 +1991,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
int peer_index = (threadIdx.x + 64) % 128;
|
||||
row_sum += shared_tensors.smem_exchange[peer_index];
|
||||
}
|
||||
#endif
|
||||
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();
|
||||
|
||||
|
||||
@ -80,17 +80,6 @@ void __global__ fmha_reference_kernel(
|
||||
if constexpr (rank<1>(decltype(coord){}) == 2) {
|
||||
offset_K = get<1,1>(coord);
|
||||
}
|
||||
|
||||
if (get<1>(problem_shape) == 0) {
|
||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
||||
mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
|
||||
mLSE(idx_Q + offset_Q, idx_L) = -INFINITY;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
|
||||
@ -111,9 +111,11 @@ void __global__ fmha_mla_reference_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifndef B2B
|
||||
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
|
||||
mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS));
|
||||
}
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@ -123,6 +125,9 @@ void __global__ fmha_mla_reference_kernel(
|
||||
}
|
||||
|
||||
ElementAcc o_scale = 1.0f / sum;
|
||||
#ifdef B2B
|
||||
o_scale = 1.0;
|
||||
#endif
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) {
|
||||
ElementAcc acc = 0;
|
||||
|
||||
@ -101,12 +101,8 @@ __global__ void reference_abs_diff_kernel(
|
||||
__shared__ double block_sum_diff;
|
||||
|
||||
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
if (data[i] == data_ref[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
double diff = fabs(data[i] - data_ref[i]);
|
||||
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
thread_max_diff = fmax(diff, thread_max_diff);
|
||||
thread_sum_diff += diff;
|
||||
}
|
||||
@ -198,11 +194,8 @@ __global__ void reference_rel_diff_kernel(
|
||||
__shared__ double block_sum_diff;
|
||||
|
||||
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
if (data[i] == data_ref[i]) {
|
||||
continue;
|
||||
}
|
||||
double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]);
|
||||
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
thread_max_diff = fmax(diff, thread_max_diff);
|
||||
thread_sum_diff += diff;
|
||||
}
|
||||
|
||||
@ -566,6 +566,8 @@ check_input_datatypes() {
|
||||
((SfVectorSizeA == 32 && cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelPtrArrayTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_base_of_v<KernelScheduleBlockScaledGemmSm100, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 64 && cute::is_base_of_v<KernelScheduleBlockScaledSparseGemmSm100, BuilderScheduleTag>)
|
||||
|
||||
@ -2256,12 +2256,16 @@ bool TestSmall(double alpha = 1.0, double beta = 1.0,
|
||||
using ElementA = typename Gemm::GemmKernel::ElementA;
|
||||
using ElementB = typename Gemm::GemmKernel::ElementB;
|
||||
using TiledMma = typename Gemm::GemmKernel::TiledMma;
|
||||
int alignment_bits = 128;
|
||||
|
||||
static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4<TiledMma, ElementA, ElementB>();
|
||||
alignment_bits = cutlass::detail::get_input_alignment_bits<ElementA, IsF8F6F4>();
|
||||
// For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes.
|
||||
int alignment_input = (alignment_bits / cute::sizeof_bits<ElementA>::value == 128) ? 0 : (alignment_bits / cute::sizeof_bits<ElementA>::value);
|
||||
// For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes.
|
||||
int alignment_bits_a = cutlass::detail::get_input_alignment_bits<ElementA, IsF8F6F4>();
|
||||
int alignment_input_a = (alignment_bits_a / cute::sizeof_bits<ElementA>::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits<ElementA>::value);
|
||||
|
||||
int alignment_bits_b = cutlass::detail::get_input_alignment_bits<ElementB, IsF8F6F4>();
|
||||
int alignment_input_b = (alignment_bits_b / cute::sizeof_bits<ElementB>::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits<ElementB>::value);
|
||||
|
||||
int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 0 : std::max(alignment_input_a, alignment_input_b);
|
||||
|
||||
|
||||
if constexpr (apply_alignment_offset) {
|
||||
|
||||
@ -71,6 +71,7 @@ cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm120
|
||||
sm120_bs_gemm_nvf4_nvf4_f32_nvf4_group_gemm_fusion.cu
|
||||
sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@ -0,0 +1,362 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief Tests for device-wide grouped GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x_ptr_array.hpp"
|
||||
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
// Pingpong kernel schedule
|
||||
TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_pingpong, row_sf) {
|
||||
using ElementInputA = float_e5m2_t;
|
||||
using ElementInputB = float_e2m1_t;
|
||||
using ElementA = cutlass::mx_float8_t<ElementInputA>;
|
||||
using ElementB = cutlass::mx_float4_t<ElementInputB>;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::float_e2m1_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementSF = cutlass::float_ue8m0_t;
|
||||
using ElementSFD = ElementSF;
|
||||
using ElementAccumulator = float;
|
||||
using GmemLayoutA = cutlass::layout::RowMajor;
|
||||
using GmemLayoutB = cutlass::layout::ColumnMajor;
|
||||
using GmemLayoutC = cutlass::layout::RowMajor;
|
||||
constexpr int SFVectorSize = 32;
|
||||
using TileShape_MNK = Shape<_128,_128,_128>;
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
|
||||
constexpr int AlignmentB = 128;
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
//
|
||||
// Construct CollectiveEpilogue
|
||||
//
|
||||
|
||||
constexpr int OutputSFVectorSize = SFVectorSize;
|
||||
// D = alpha * acc + beta * C
|
||||
// With Row-major BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, GmemLayoutC,
|
||||
ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, GmemLayoutC *, AlignmentC,
|
||||
ElementD, GmemLayoutC *, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
//
|
||||
// Construct CollectiveMainloop
|
||||
//
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
ElementA, GmemLayoutA *, AlignmentA,
|
||||
ElementB, GmemLayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
auto pass = test::gemm::device::TestSmallFusion<Gemm>(1.0, 0.5);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_pingpong, silu_row_sf) {
  // Same configuration as the row_sf pingpong test above, but the epilogue
  // applies a SiLU activation: D = SiLu(alpha * acc + beta * C), with
  // row-major block scale factors generated for the e2m1 output D.
  using ElementInputA = float_e5m2_t;
  using ElementInputB = float_e2m1_t;
  using ElementA = cutlass::mx_float8_t<ElementInputA>;
  using ElementB = cutlass::mx_float4_t<ElementInputB>;
  using ElementC = cutlass::half_t;
  using ElementD = cutlass::float_e2m1_t;
  using ElementCompute = float;
  using ElementAccumulator = float;
  // Scale-factor type for the generated output (D) block scale factors.
  using ElementSF = cutlass::float_ue4m3_t;
  using ElementSFD = ElementSF;
  using GmemLayoutA = cutlass::layout::RowMajor;
  using GmemLayoutB = cutlass::layout::ColumnMajor;
  using GmemLayoutC = cutlass::layout::RowMajor;
  constexpr int SFVectorSize = 32;
  using TileShape_MNK = Shape<_128,_128,_128>;
  using ClusterShape_MNK = Shape<_1,_1,_1>;

  // Alignments are in elements. B is fp4: the minimum input alignment for
  // fp4/fp6 operands is 128 elements, hence the fixed value.
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
  constexpr int AlignmentB = 128;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  //
  // Construct CollectiveEpilogue
  //

  constexpr int OutputSFVectorSize = SFVectorSize;
  // D = SiLu(alpha * acc + beta * C)
  // With Row-major BlockScaleFactor generation.
  using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
      cutlass::epilogue::thread::SiLu,
      OutputSFVectorSize,
      ElementD,
      ElementCompute,
      ElementSFD, GmemLayoutC,
      ElementC>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementCompute,
      ElementC, GmemLayoutC *, AlignmentC,
      ElementD, GmemLayoutC *, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto,
      FusionOperation
    >::CollectiveOp;

  //
  // Construct CollectiveMainloop
  //
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
      ElementA, GmemLayoutA *, AlignmentA,
      ElementB, GmemLayoutB *, AlignmentB,
      ElementAccumulator,
      TileShape_MNK, ClusterShape_MNK,
      // Carve the epilogue's shared storage out of the mainloop stage budget.
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong
    >::CollectiveOp;

  // Grouped (ptr-array) problem shape: one (M,N,K) per group.
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
      CollectiveMainloop,
      CollectiveEpilogue
    >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  auto pass = test::gemm::device::TestSmallFusion<Gemm>(1.0, 0.5);
  EXPECT_TRUE(pass);
}
|
||||
|
||||
|
||||
// Cooperative kenel schedule
|
||||
TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_cooperative, row_sf) {
|
||||
using ElementInputA = float_e5m2_t;
|
||||
using ElementInputB = float_e2m1_t;
|
||||
using ElementA = cutlass::mx_float8_t<ElementInputA>;
|
||||
using ElementB = cutlass::mx_float4_t<ElementInputB>;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::float_e2m1_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementSF = cutlass::float_ue4m3_t;
|
||||
using ElementSFD = ElementSF;
|
||||
using ElementAccumulator = float;
|
||||
using GmemLayoutA = cutlass::layout::RowMajor;
|
||||
using GmemLayoutB = cutlass::layout::ColumnMajor;
|
||||
using GmemLayoutC = cutlass::layout::RowMajor;
|
||||
constexpr int SFVectorSize = 32;
|
||||
using TileShape_MNK = Shape<_128,_128,_128>;
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
|
||||
constexpr int AlignmentB = 128;
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
//
|
||||
// Construct CollectiveEpilogue
|
||||
//
|
||||
|
||||
constexpr int OutputSFVectorSize = SFVectorSize;
|
||||
// D = alpha * acc + beta * C
|
||||
// With Row-major BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, GmemLayoutC,
|
||||
ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, GmemLayoutC *, AlignmentC,
|
||||
ElementD, GmemLayoutC *, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
//
|
||||
// Construct CollectiveMainloop
|
||||
//
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
ElementA, GmemLayoutA *, AlignmentA,
|
||||
ElementB, GmemLayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
auto pass = test::gemm::device::TestSmallFusion<Gemm>(1.0, 0.5);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_cooperative, silu_row_sf) {
  // Same configuration as the cooperative row_sf test above, but the epilogue
  // applies a SiLU activation: D = SiLu(alpha * acc + beta * C), with
  // row-major block scale factors generated for the e2m1 output D.
  using ElementInputA = float_e5m2_t;
  using ElementInputB = float_e2m1_t;
  using ElementA = cutlass::mx_float8_t<ElementInputA>;
  using ElementB = cutlass::mx_float4_t<ElementInputB>;
  using ElementC = cutlass::half_t;
  using ElementD = cutlass::float_e2m1_t;
  using ElementCompute = float;
  using ElementAccumulator = float;
  // Scale-factor type for the generated output (D) block scale factors.
  using ElementSF = cutlass::float_ue4m3_t;
  using ElementSFD = ElementSF;
  using GmemLayoutA = cutlass::layout::RowMajor;
  using GmemLayoutB = cutlass::layout::ColumnMajor;
  using GmemLayoutC = cutlass::layout::RowMajor;
  constexpr int SFVectorSize = 32;
  using TileShape_MNK = Shape<_128,_128,_128>;
  using ClusterShape_MNK = Shape<_1,_1,_1>;

  // Alignments are in elements. B is fp4: the minimum input alignment for
  // fp4/fp6 operands is 128 elements, hence the fixed value.
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
  constexpr int AlignmentB = 128;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  //
  // Construct CollectiveEpilogue
  //

  constexpr int OutputSFVectorSize = SFVectorSize;
  // D = SiLu(alpha * acc + beta * C)
  // With Row-major BlockScaleFactor generation.
  using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
      cutlass::epilogue::thread::SiLu,
      OutputSFVectorSize,
      ElementD,
      ElementCompute,
      ElementSFD, GmemLayoutC,
      ElementC>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementCompute,
      ElementC, GmemLayoutC *, AlignmentC,
      ElementD, GmemLayoutC *, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto,
      FusionOperation
    >::CollectiveOp;

  //
  // Construct CollectiveMainloop
  //
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
      ElementA, GmemLayoutA *, AlignmentA,
      ElementB, GmemLayoutB *, AlignmentB,
      ElementAccumulator,
      TileShape_MNK, ClusterShape_MNK,
      // Carve the epilogue's shared storage out of the mainloop stage budget.
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      // NOTE(review): the test name says "cooperative" but the schedule is
      // KernelScheduleAuto (the pingpong tests request an explicit schedule).
      // Confirm Auto resolves to the cooperative ptr-array schedule here.
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Grouped (ptr-array) problem shape: one (M,N,K) per group.
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
      CollectiveMainloop,
      CollectiveEpilogue
    >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  auto pass = test::gemm::device::TestSmallFusion<Gemm>(1.0, 0.5);
  EXPECT_TRUE(pass);
}
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
Reference in New Issue
Block a user