CUTLASS 3.4.0 (#1286)

* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
Pradeep Ramani
2023-12-29 12:21:31 -08:00
committed by GitHub
parent b7508e3379
commit 8236f30675
211 changed files with 11409 additions and 2763 deletions

View File

@ -1,4 +1,12 @@
# NVIDIA CUTLASS Changelog
## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4) (2023-12-29)
* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm)
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* NamedBarriers usability improvements; the list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released.
* Improved [CuTe TMA Tensor](/media/docs/cute/0z_tma_tensors.md) documentation.
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31)
* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.

View File

@ -40,7 +40,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
project(CUTLASS VERSION 3.3.0 LANGUAGES CXX)
project(CUTLASS VERSION 3.4.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 11.3)
@ -681,6 +681,12 @@ endif()
################################################################################
set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default" CACHE STRING "Default
activated test sets. In `make test` mode, this string determines the
active set of tests. In `ctest` mode, this value can be overridden
with CUTLASS_TEST_SETS environment variable when running the ctest
executable.")
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake)
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")
@ -701,11 +707,12 @@ function(cutlass_add_executable_tests NAME TARGET)
# generating the full variable name to be referenced.
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
# test results to speed up test runtime.
# TEST_SETS_SUPPORTED: A list of test set names these tests support.
#
set(options DISABLE_EXECUTABLE_INSTALL_RULE)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (NOT DEFINED __DISABLE_TESTS)
@ -715,6 +722,12 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
if (NOT DEFINED __TEST_SETS_SUPPORTED)
set(__TEST_SETS_SUPPORTED ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})
endif()
set(TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED})
if (__RESULT_CACHE_FILE)
add_custom_command(
@ -816,8 +829,6 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
file(MAKE_DIRECTORY ${TEST_GEN_DIR})
set(TEST_SETS_SUPPORTED default)
set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT ON)
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
@ -873,9 +884,9 @@ if (CUTLASS_INSTALL_TESTS)
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")
file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SET})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SET} \"default\")\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "cmake_policy(SET CMP0057 NEW) # Allow IN_LIST for if()\n\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SETS} ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n")
foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
@ -897,9 +908,15 @@ write_basic_package_version_file(
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
COMPATIBILITY AnyNewerVersion)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
@ONLY
)
install(
FILES
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/
)

View File

@ -13,6 +13,7 @@ Cris Cecka<br />
Aniket Shivam<br />
Jack Kosaian<br />
Mark Hoemmen<br />
Richard Cai<br />
Honghao Lu<br />
Ethan Yan<br />
Haicheng Wu<br />
@ -21,6 +22,8 @@ Dustyn Blasig<br />
Fengqi Qiao<br />
Duane Merrill<br />
Yujia Zhai<br />
Rawn Henry<br />
Sergey Klevtsov<br />
Shang Zhang<br />
Piotr Majcher<br />
Paul Springer<br />
@ -55,6 +58,7 @@ Alan Kaatz<br />
Tina Li<br />
Timmy Liu<br />
Wei Liu<br />
Tim Martin<br />
Duane Merrill<br />
Kevin Siu<br />
Markus Tavenrath<br />

View File

@ -248,11 +248,15 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
endif()
if (NOT DEFINED __BATCH_SOURCES)
set(__BATCH_SOURCES ON)
endif()
if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
endif()
if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1)
set(CUDA_FILE_ARGS)
set(TARGET_SOURCE_ARGS)

View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 3.3
# CUTLASS 3.4
_CUTLASS 3.3 - October 2023_
_CUTLASS 3.4 - December 2023_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@ -41,17 +41,14 @@ and improves code composability and readability. More documentation specific to
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
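As a rough illustration of that mapping, here is a minimal sketch (the struct names and the NHWC fprop convention are assumptions chosen for this illustration, not CUTLASS's convolution API): the GEMM extents fall directly out of the convolution problem size.

```cpp
// Minimal sketch of the implicit-GEMM problem-size mapping for an NHWC
// forward convolution. Illustrative only; CUTLASS derives these extents
// internally from its own convolution problem-size types.
#include <cstdio>

struct ConvFpropShape {
  int N, H, W, C;  // input  N x H x W x C
  int K, R, S;     // filter K x R x S x C (C shared with the input)
  int P, Q;        // output spatial extents
};

struct GemmShape { int M, N, K; };

// Each output pixel (n, p, q) becomes one GEMM row, each filter k one GEMM
// column, and the reduction runs over the C*R*S values gathered by im2col.
constexpr GemmShape implicit_gemm_fprop(ConvFpropShape c) {
  return { c.N * c.P * c.Q, c.K, c.C * c.R * c.S };
}

int main() {
  // Example: 8 x 224 x 224 x 64 input, 128 filters of size 3 x 3, 224 x 224 output.
  constexpr GemmShape g = implicit_gemm_fprop({8, 224, 224, 64, 128, 3, 3, 224, 224});
  std::printf("GEMM M=%d N=%d K=%d\n", g.M, g.N, g.K);  // M=401408, N=128, K=576
  return 0;
}
```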
# What's New in CUTLASS 3.3
# What's New in CUTLASS 3.4
CUTLASS 3.3.0 is an update to CUTLASS adding:
CUTLASS 3.4.0 is an update to CUTLASS adding:
- New [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input types with optimal performance.
- New [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8} and upcast on operandA {s8, u8} x {fp16, bf16}. They also include fast numeric conversion recipes and warp level shuffles to achieve optimal performance.
- New [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors (across s8/fp8/fp16/bf16/tf32 types) with optimal performance. As a part of this, new kernel schedules, and Copy Ops [SM80\_CP\_ASYNC\_CACHE\_\*](/include/cute/arch/copy_sm80.hpp) were also added.
- EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
- Various subbyte enhancements like tagged device ptrs, support for vectorized copy, various operators to treat subbyte iterators as pointers, and full-fledged CuTe Tensor support.
- Support for Clang as a host compiler.
- Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm), commonly used in optimizing Mixture-of-Experts models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- Improvements to NamedBarriers, including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
Minimum requirements:
@ -95,7 +92,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1.
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2 and CUDA 12.3.1.
## Operating Systems
We have tested the following environments.
@ -107,6 +104,7 @@ We have tested the following environments.
| Ubuntu 22.04 | GCC 11.2.0 |
| Ubuntu 22.04 | Clang 10.0.0 |
| Ubuntu 22.04 | Clang 14.0.6 |
| Ubuntu 22.04 | Clang 17.0.6 |
| Windows 10.0 | Visual Studio 2019 v16.11.27 |
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.

View File

@ -2,10 +2,16 @@
set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)
#? if (DEFINED ENV{CUTLASS_TEST_SET} AND NOT ENV{CUTLASS_TEST_SET} IN_LIST TEST_SETS_SUPPORTED)
#? message(STATUS "Skipping tests for @TEST_EXE_PATH@ as $ENV{CUTLASS_TEST_SET} is not in the set of ${TEST_SETS_SUPPORTED}.")
#? return()
#? endif()
if (NOT DEFINED ENV{CUTLASS_TEST_SETS})
set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@)
endif()
foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS})
if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED)
message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].")
return()
endif()
endforeach()
set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)

View File

@ -7,6 +7,3 @@ if(TARGET nvidia::cutlass::CUTLASS)
endif()
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
# For backward compatibility with the old name
add_library(cutlass_lib ALIAS cutlass_library)

View File

@ -291,8 +291,8 @@ int run() {
LayoutInputB,
ElementOutput,
LayoutOutput,
int32_t,
int32_t>
ElementComputeEpilogue,
ElementComputeEpilogue>
gemm_device;
// Launch device reference gemm kernel

View File

@ -279,7 +279,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "28_ampere_gemm_bias_fusion example\n\n"
out << "23_ampere_operand_gemm_reduction_fusion\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M\n"
@ -297,7 +297,7 @@ struct Options {
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/23_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n";
<< "$ ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --m=1024 --n=1024 --k=1024 \n\n";
return out;
}

View File

@ -602,7 +602,7 @@ Result profile_convolution(Options const &options) {
std::stringstream ss;
ss << "26_ampere_fused_wgrad_batch_normalization_"
ss << "30_wgrad_split_k_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()

View File

@ -251,7 +251,7 @@ struct Options {
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/31_transposed_conv2d/31_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n";
<< "$ ./examples/34_transposed_conv2d/34_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n";
return out;
}

View File

@ -398,7 +398,7 @@ struct Options {
<< "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n"
<< "# Execute Grouped SYR2K and profile with NSight\n"
<< "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";
<< "$ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";
return out;
}

View File

@ -306,7 +306,7 @@ struct Options {
/// Prints the usage statement.
std::ostream &print_usage(std::ostream &out) const {
out << "41_depthwise_gemm_fprop example\n\n"
out << "46_depthwise_gemm_fprop example\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
@ -554,7 +554,7 @@ Result profile_convolution(Options const &options) {
if (options.save_workspace) {
std::stringstream ss;
ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
ss << "46_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
<< "x" << options.input_size.w() << "x" << options.input_size.c() << "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x"
<< options.filter_size.w() << "x" << options.filter_size.c() << ".dat";

View File

@ -291,7 +291,7 @@ struct ExampleRunner {
using CustomEVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
@ -302,7 +302,7 @@ struct ExampleRunner {
// Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp
// to the CollectiveBuilder. This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts.
// These tags also provide additional metadata that can be queried at compile time.
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementScalar, RoundStyle>;
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,

View File

@ -568,7 +568,7 @@ struct ExampleRunner
if (options.reference_check) {
if (!verify()) {
std::cout << "Failed validation" << std::endl;
#if 1
#if 0
debug_output(std::cout);
#endif
return false;

View File

@ -122,7 +122,7 @@ public:
static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same.");
static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups;
static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance.");

View File

@ -109,6 +109,8 @@
namespace example
{
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
struct Options {
bool help;
@ -724,6 +726,7 @@ private:
return true;
}
};
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
} // namespace example
@ -749,7 +752,7 @@ int main(int argc, char const **argv)
if (notSupported) {
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
example::Options options;
options.parse(argc, argv);
@ -970,6 +973,6 @@ int main(int argc, char const **argv)
result &= runner.run(options);
}
#endif
return result ? EXIT_SUCCESS : EXIT_FAILURE;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
}

View File

@ -109,9 +109,11 @@ using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// Auxiliary matrix configuration
// Auxiliary matrix configuration and other fusion types
using ElementAux = ElementC;
using LayoutAux = LayoutC;
using ElementAmax = float;
using ElementBias = float;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
@ -124,7 +126,7 @@ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux>;
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,

View File

@ -38,19 +38,41 @@
using INT8 tensor cores.
The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
A and B in the main loop. Consequently, it is essential to consider this when constructing the epilogue, as illustrated in this example.
A and B in the main loop. However, implicit swaps do not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue,
as illustrated in this example.
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
If A is being scaled, the scales should have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k].
The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group size
equal to the gemm problem K. (A worked sketch of the resulting scale shapes follows this comment block.)
Limitations:
1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}.
2) The narrow type must always be in K-major format.
3) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format.
4) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure the narrow type passes through the register file, and TMA epilogues do not currently support swap + transpose operations.
We plan to address this limitation in the future.
3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
4) The scales and the zeros must have the same layout and groupsize.
5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format.
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands.
Examples:
Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0)
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2
Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire
matrix (group size is the same as the gemm k dimension).
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1
Runs the mixed input gemm, and applies a scaling factor and adds a zero-point to B before mma (mode 2). Uses a group size of 128.
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=128 --mode=2
*/
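To make the scale and zero-point bookkeeping described in the comment above concrete, here is a minimal standalone sketch; the problem sizes are assumptions for illustration and the code is not part of the example itself.

```cpp
// Minimal sketch of the scale/zero-point shapes implied by the comment above.
// All sizes below are illustrative assumptions.
#include <cassert>

int main() {
  int const K = 8192;          // GEMM K extent
  int const group_size = 128;  // value passed via --g
  int const N = 5120;          // extent of the scaled operand (B in this example)

  // scale_k = ceil_div(problem_k, group_size)
  int const scale_k = (K + group_size - 1) / group_size;
  assert(scale_k == 64);

  // For a scaled B operand, scales (and zeros) are MN-major with shape [N, scale_k].
  // Choosing group_size == K collapses this to a single per-column vector [N, 1].
  int const scale_rows = N;
  int const scale_cols = scale_k;
  assert(scale_rows == 5120 && scale_cols == 64);

  // Per-element reference dequantization with a zero-point (mode 2) then reads:
  //   B_dq[n][k] = ElementScale(B_q[n][k]) * scale[n][k / group_size] + zero[n][k / group_size]
  return 0;
}
```

With the example's default `--k=4096` and `--g=128`, the same arithmetic gives `scale_k = 32`.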
#include <iostream>
@ -79,20 +101,28 @@
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
#include "unfused_weight_dequantize.h"
#include "unfused_weight_dequantize.hpp"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
// This is just an example, so we use a regular enum so we can compare directly to the command-line int.
enum GemmMode {
ConvertOnly,
ScaleOnly,
ScaleWithZeroPoint
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::half_t;
using QuantType = int8_t;
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
// A matrix configuration
using ElementA = MmaType; // Element type for A matrix operand
using ElementA = MmaType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
@ -101,6 +131,14 @@ using ElementB = QuantType; // E
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using ElementZero = cutlass::half_t;
using ElementScale = cutlass::half_t;
using LayoutScale = cutlass::layout::RowMajor;
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
@ -109,51 +147,107 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_256,_64>; // Threadblock-level tile size
using TileShape = Shape<_128,_256,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueTileType,
ElementAccumulator, ElementAccumulator,
// Lie here about layout of C and D since we do swap and transpose trick
// Transpose layout of D here since we use explicit swap + transpose
// the void type for C tells the builder to allocate 0 smem for the C matrix.
// We can enable this if beta == 0 by changing ElementC to void below.
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
cutlass::epilogue::NoSmemWarpSpecialized // This is the only epi supporting the required swap + transpose.
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
// ============================================================ MIXED INPUT NO SCALES ============================================================================
// The collective will infer that the narrow type should be upcasted to the wide type.
// We swap A and B operands to the builder here
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementB, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveMainloopConvertOnly,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ==================================================================
// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format.
using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale, ElementZero>, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleWithZeroPoint,
CollectiveEpilogue
>;
using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleWithZeroPoint>;
// =================================================================================================================================================================
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = typename GemmKernelScaleWithZeroPoint::StrideC;
using StrideD = typename GemmKernelScaleWithZeroPoint::StrideD;
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
//
// Data members
@ -163,20 +257,25 @@ using StrideD = typename Gemm::GemmKernel::StrideD;
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideC_ref stride_C_ref;
StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;
// Initialization functions don't handle sub-byte types so we use uint8 to initialize and a separate
// kernel to pack the data if it is necessary.
using InitializationType = cute::conditional_t<cute::sizeof_bits_v<QuantType> < 8, uint8_t, QuantType>;
// Scale and Zero share a stride since the layout and shapes must be the same.
using StrideS = typename CollectiveMainloopScaleWithZeroPoint::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;
cutlass::HostTensor<typename Gemm::ElementA, LayoutA> tensor_A;
cutlass::HostTensor<InitializationType, LayoutB> tensor_B_init;
cutlass::HostTensor<typename Gemm::ElementB, LayoutB> tensor_B;
cutlass::HostTensor<typename Gemm::ElementA, LayoutB> tensor_B_dq;
cutlass::HostTensor<typename Gemm::ElementC, LayoutC> tensor_C;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, LayoutD> tensor_D;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, LayoutD> tensor_ref_D;
cutlass::HostTensor<MmaType, LayoutA> tensor_A;
cutlass::HostTensor<QuantType, LayoutB> tensor_B;
cutlass::HostTensor<MmaType, LayoutB> tensor_B_dq;
cutlass::HostTensor<ElementScale, LayoutScale> tensor_scale;
cutlass::HostTensor<ElementZero, LayoutScale> tensor_zero;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_D;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@ -192,7 +291,9 @@ struct Options {
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 1000;
int mode = 2;
int m = 5120, n = 4096, k = 4096;
int g = 128;
int l = 1;
// Parses the command line
@ -208,6 +309,8 @@ struct Options {
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("g", g);
cmd.get_cmd_line_argument("mode", mode);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
@ -224,13 +327,15 @@ struct Options {
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
<< " --g=<int> The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n";
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
return out;
}
@ -289,114 +394,146 @@ bool initialize_tensor(
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
view, seed, scope_max, scope_min);
return true;
}
template <class QuantElement, typename Element, typename Layout>
template <typename Element, typename Layout>
bool initialize_quant_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed=2023) {
Element scope_max, scope_min;
constexpr int bits_input = cute::sizeof_bits_v<QuantElement>;
static_assert(bits_input <= 8, "Quantization type can be at most 8 bits");
if constexpr (bits_input == 8) {
// Directly init 1-byte types
static_assert(cute::is_same_v<QuantElement, Element>, "Init type should equal quant type for 1 byte types");
scope_max = std::numeric_limits<QuantElement>::max();
scope_min = std::numeric_limits<QuantElement>::min();
} else {
static_assert(cute::is_same_v<uint8_t, Element>, "Init type should be uint8_t for sub-byte types");
scope_max = (1 << bits_input);
scope_min = 0;
}
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
view, seed, scope_max, scope_min);
return true;
}
template <class Element, class Layout>
bool initialize_with_one(
cutlass::TensorView<Element, Layout> view) {
cutlass::reference::host::TensorFill(view, Element(1.0f));
bool initialize_scale(
cutlass::TensorView<Element, Layout> view,
const Options &options) {
if (options.mode == GemmMode::ConvertOnly) {
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(1.0f));
}
else {
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
const float max_dequant_val = 4.f;
const float min_dequant_val = 0.5f;
float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
}
return true;
}
template <class ElementDst, class ElementSrc, class Layout, class L>
void prepare_packed_data(cutlass::HostTensor<ElementDst, Layout> view_dst_data,
cutlass::HostTensor<ElementSrc, Layout> view_src_data,
const L& cute_layout) {
if constexpr (cute::is_same_v<ElementSrc, ElementDst>) {
view_dst_data.copy_in_device_to_device(view_src_data.device_data());
}
else {
pack_data(view_dst_data.device_data(), view_src_data.device_data(), cute_layout);
template <class Element, class Layout>
bool initialize_zero(
cutlass::TensorView<Element, Layout> view,
const Options &options) {
if (options.mode == GemmMode::ScaleWithZeroPoint) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2.0f, -2.0f);
} else {
// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(0.0f));
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
auto shape_b = cute::make_shape(options.n, options.k, options.l);
const int scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
// Reverse stride here due to swap and transpose
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
tensor_A.resize(a_coord);
tensor_B_init.resize(b_coord);
tensor_B.resize(b_coord);
tensor_B_dq.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
// We need scales since the "dequantize" kernels expects them. We just set them to 1 so the values get converted
// to the mma type.
cutlass::HostTensor<MmaType, cutlass::layout::RowMajor> tensor_scale;
tensor_scale.resize({1 * options.l, options.n});
tensor_scale.resize({scale_k * options.l, options.n});
tensor_zero.resize({scale_k * options.l, options.n});
initialize_tensor(tensor_A.host_view(), seed + 2022);
initialize_quant_tensor<QuantType>(tensor_B_init.host_view(), seed + 2021);
initialize_quant_tensor(tensor_B.host_view(), seed + 2021);
initialize_tensor(tensor_C.host_view(), seed + 2020);
initialize_with_one(tensor_scale.host_view());
initialize_scale(tensor_scale.host_view(), options);
initialize_zero(tensor_zero.host_view(), options);
tensor_A.sync_device();
tensor_B_init.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_scale.sync_device();
tensor_zero.sync_device();
auto layout_B = make_layout(shape_b, stride_B);
prepare_packed_data(tensor_B, tensor_B_init, layout_B);
auto shape_scale = cute::make_shape(options.n, 1, options.l);
auto layout_scale = make_layout(shape_scale);
dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), layout_scale);
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
tensor_B.sync_host();
dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g);
tensor_B_dq.sync_host();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
template <typename Args>
Args args_from_options(const Options &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};
return arguments;
// Swap the A and B tensors, as well as problem shapes here.
if (options.mode == GemmMode::ConvertOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};
} else {
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
exit(-1);
}
}
bool verify(const Options &options) {
@ -404,44 +541,64 @@ bool verify(const Options &options) {
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B_dq.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
// In this example, we use the GPU default kernels as a reference (unfused scale)
// This is to avoid numerical differences from different accumulation order.
using unused_t = decltype(D);
// Again, due to numerical differences, we must use fast acc here when the mma type is
// FP8 as the fused implementation only supports fast acc at the moment.
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
cutlass::reference::host::GettEpilogueParams<
typename Gemm::EpilogueOutputOp::ElementScalar,
typename Gemm::EpilogueOutputOp::ElementScalar,
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
unused_t, // aux
unused_t, // valpha
unused_t // vbeta
> epilogue_params;
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
ScheduleRef
>::CollectiveOp;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref}
};
// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
tensor_ref_D.sync_host();
const ElementD epsilon(1e-2f);
const ElementD non_zero_floor(1e-4f);
bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor);
return passed;
}
@ -455,7 +612,7 @@ int run(Options &options)
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
auto arguments = args_from_options<typename Gemm::Arguments>(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
@ -549,7 +706,26 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
if (options.mode == GemmMode::ConvertOnly) {
std::cout << "Running in no scale mode." << std::endl;
run<GemmConvertOnly>(options);
}
else if (options.mode == GemmMode::ScaleOnly) {
if (options.g == options.k) {
std::cout << "Running in per-column scale mode." << std::endl;
} else {
std::cout << "Running in group scale mode." << std::endl;
}
run<GemmScaleOnly>(options);
}
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
if (options.g == options.k) {
std::cout << "Running in per-column scale and zero mode." << std::endl;
} else {
std::cout << "Running in group scale and zero mode." << std::endl;
}
run<GemmScaleWithZeroPoint>(options);
}
#endif
return 0;

View File

@ -27,9 +27,33 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.
set(TEST_DIRECT_BATCHED --m=2048 --n=2048 --k=2048 --l=2 --mode=0 --iterations=0) # Direct conversion
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 --iterations=0) # Per Column scaling
set(TEST_SCALE_ZERO_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=2 --iterations=0) # Per Column scaling with zero-point
set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --g=512 --mode=1 --iterations=0) # Group-wise scaling
set(TEST_SCALE_ZERO_GROUPED --m=2048 --n=5120 --k=8192 --g=256 --mode=2 --iterations=0) # Group-wise scaling with zero-point
set(TEST_SCALE_RESIDUE --m=128 --n=128 --k=320 --g=128 --mode=1 --iterations=0) # Final group has residue
set(TEST_SCALE_ZERO_RESIDUE --m=128 --n=128 --k=192 --g=128 --mode=2 --iterations=0) # Final group has residue
set(TEST_ALPHA_BETA --alpha=0.5 --beta=0.7 --mode=2 --iterations=0) # Alpha and Beta with default shapes
cutlass_example_add_executable(
55_hopper_mixed_dtype_gemm
55_hopper_mixed_dtype_gemm.cu
TEST_COMMAND_OPTIONS
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_ZERO_PERCOL
TEST_SCALE_GROUP
TEST_SCALE_ZERO_GROUPED
TEST_SCALE_RESIDUE
TEST_SCALE_ZERO_RESIDUE
TEST_ALPHA_BETA
)

View File

@ -9,17 +9,14 @@ This first version only supports mixed type GEMMs using TMA.
## Performance
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x int8` for problems that are compute bound.
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.
We are currently optimizing the following cases:
1. Memory bound cases for all types
1. Compute bound cases for `{16-bit, 8-bit} x {4-bit, 2-bit}`
As a result, we do not suggest using this example as a benchmarking reference until all of our optimizations are complete (this will be clearly stated in this README in a future release).
## Limitations
* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen to control the layout of the epilogue as shown in the example. Note that TMA epilogues currently do not support swap + transpose, so non-tma epilogues must be used in this case. We plan to relax this limitation in a future release.
* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen. Note that TMA epilogues currently do not support *implicit* swap + transpose, so non-TMA epilogues must be used in this case. We plan to relax this limitation in a future release.
* The layout of the narrow type must be K-major. This means the following:
* Narrow type is the A operand: Must be Row-Major
@ -29,8 +26,12 @@ As a result, we do not suggest using this example as a benchmarking reference un
* TMA requires an alignment of 128 bits. As a result, for a type with `B` bits, `B x TILE_K` must be a multiple of 128 bits.
* The scale and zero-point types must be two bytes or more.
* The group size must be equal to the gemm-k size (indicating a broadcast), or it must be a multiple of the threadblock-k size. See the sketch below for a worked check of these constraints.
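A minimal sketch, assuming a 4-bit narrow operand, a 128-wide threadblock K tile, and a group size of 128 (illustrative values, not taken from the library), of how the TMA-alignment and group-size rules above can be checked at compile time:

```cpp
// Minimal sketch: compile-time checks for the TMA-alignment and group-size
// rules listed above. All values are illustrative assumptions.
constexpr int narrow_bits   = 4;    // e.g. a 4-bit quantized narrow operand
constexpr int tile_k        = 128;  // assumed threadblock K tile
constexpr int threadblock_k = tile_k;
constexpr int gemm_k        = 8192;
constexpr int group_size    = 128;

// TMA needs 128-bit alignment: bits-per-element * TILE_K must be a multiple of 128 bits.
static_assert((narrow_bits * tile_k) % 128 == 0,
              "narrow operand tile does not satisfy 128-bit TMA alignment");

// The group size must equal gemm K (broadcast) or be a multiple of the threadblock K size.
static_assert(group_size == gemm_k || group_size % threadblock_k == 0,
              "group size must equal gemm K or be a multiple of threadblock K");

int main() { return 0; }
```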
## Upcoming features
* Support for applying scales after conversion, but before issuing tensor core math (input scale fusion) is planned for v3.4.
* Optimizations for memory bound cases.
* Many optimizations for SOL performance.
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.

View File

@ -9,13 +9,15 @@ template <class QuantizedElement,
class DequantizedElement,
class OperandLayout,
class ElementScale,
class ElementZero,
class ScaleBroadCastLayout,
class ThrLayout>
__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
const QuantizedElement* q_buffer,
const OperandLayout operand_layout,
const ElementScale* scale_buffer,
const ScaleBroadCastLayout broadcasted_scale_layout,
QuantizedElement const* q_buffer,
OperandLayout const operand_layout,
ElementScale const* scale_buffer,
ElementZero const* zero_buffer,
ScaleBroadCastLayout const broadcasted_scale_layout,
ThrLayout thr_layout) {
using namespace cute;
@ -33,6 +35,7 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
// It is expected that K % G == 0
Tensor gmem_scale_broadcasted = make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
Tensor gmem_zero_broadcasted = make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
// Assign 1 thread per element in the thread block
auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); //
@ -41,16 +44,21 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
// Tile across the block
auto gOp_dq = local_tile(gmem_op_dq, blk_shape, blk_coord);
auto gScale = local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
auto gZero = local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
auto gOp_q = local_tile(gmem_op_q, blk_shape, blk_coord);
auto tOpDq_gOpDq = local_partition(gOp_dq, thr_layout, threadIdx.x);
auto tScale_gScale = local_partition(gScale, thr_layout, threadIdx.x);
auto tZero_gZero = local_partition(gZero, thr_layout, threadIdx.x);
auto tOpQ_gOpQ = local_partition(gOp_q, thr_layout, threadIdx.x);
// Make a fragment of registers to hold gmem loads
Tensor rmem_op_q = make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
Tensor rmem_scale = make_fragment_like(tScale_gScale(_, _, _, 0));
Tensor rmem_zero = make_fragment_like(tZero_gZero(_, _, _, 0));
Tensor rmem_op_dq = make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
Tensor rmem_op_scaled = make_fragment_like<ElementScale>(rmem_op_dq);
Tensor rmem_zero_buf = make_fragment_like<ElementScale>(rmem_zero);
Tensor pred_id = make_identity_tensor(shape(operand_layout));
auto pred_blk_tile = local_tile(pred_id, blk_shape, blk_coord);
@ -63,8 +71,12 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
if (thread_offset < size<0>(operand_layout)) {
copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
copy(tScale_gScale(_, _, _, ii), rmem_scale);
transform(rmem_op_q, rmem_op_dq, [] (const QuantizedElement& elt) { return DequantizedElement(elt); } );
transform(rmem_op_dq, rmem_scale, rmem_op_dq, multiplies{});
copy(tZero_gZero(_, _, _, ii), rmem_zero);
transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{});
transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{});
transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
}
}
@ -74,12 +86,15 @@ template <class QuantizedElement,
class DequantizedElement,
class OperandLayout,
class ElementScale,
class ElementZero,
class ScaleLayout>
void dequantize_weight(DequantizedElement* dq_buffer,
const QuantizedElement* q_buffer,
const OperandLayout operand_layout,
const ElementScale* scale_buffer,
const ScaleLayout scale_layout) {
QuantizedElement const* q_buffer,
OperandLayout const operand_layout,
ElementScale const* scale_buffer,
ElementZero const* zero_buffer,
ScaleLayout const scale_layout,
int const group_size) {
using namespace cute;
@ -87,9 +102,9 @@ void dequantize_weight(DequantizedElement* dq_buffer,
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
const auto num_rows = get<0>(shape(operand_layout));
const auto num_cols = get<1>(shape(operand_layout)); // [MN, K, L]
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
const auto num_cols_scale = get<1>(shape(scale_layout)); // [MN, G, L]
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
if (num_rows != size<0>(scale_layout)) {
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
@ -98,81 +113,18 @@ void dequantize_weight(DequantizedElement* dq_buffer,
exit(-1);
}
if (num_cols % num_cols_scale != 0) {
std::cerr << "Invalid shape for weight / scales. Weight cols must be a multiple of scale cols."
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
<< std::endl;
exit(-1);
}
const auto scale_stride0 = get<0>(stride(scale_layout));
const auto scale_stride1 = get<1>(stride(scale_layout));
const auto scale_stride2 = get<2>(stride(scale_layout));
auto scale_shape_bcast = make_shape(num_rows, make_shape(num_cols / num_cols_scale, num_cols_scale), batches);
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
const auto blocks_x = num_cols;
const auto blocks_x = gemm_k;
const auto blocks_y = batches;
dim3 blocks(blocks_x, blocks_y, 1);
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, scale_layout_bcast, thr_layout);
CUDA_CHECK(cudaDeviceSynchronize());
}
template <int ELTS_PER_THREAD,
class SubbyteType>
__global__ void pack_data_kernel(SubbyteType* packed_data_ptr,
const uint8_t* unpacked_data_ptr,
const size_t max_elts) {
using namespace cute;
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
uint8_t data[ELTS_PER_THREAD];
if (tid < max_elts) {
const uint8_t* read_ptr = unpacked_data_ptr + tid * ELTS_PER_THREAD;
for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) {
data[ii] = read_ptr[ii];
}
using WriteType = cute::array_subbyte<SubbyteType, ELTS_PER_THREAD>;
WriteType* write_ptr = reinterpret_cast<WriteType*>(packed_data_ptr);
WriteType packed_data;
for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) {
SubbyteType elt(data[ii]);
packed_data[ii] = elt;
}
write_ptr[tid] = packed_data;
}
}
template <class SubbyteType,
class OperandLayout>
void pack_data(SubbyteType* packed_data, const uint8_t* unpacked_data, const OperandLayout operand_layout) {
static_assert(cute::sizeof_bits_v<SubbyteType> < 8, "First operand must be a sub-byte type");
constexpr int packed_elements = 8 / cute::sizeof_bits_v<SubbyteType>;
if (cute::stride<0>(operand_layout) == 1 && (cute::shape<0>(operand_layout) % packed_elements)) {
std::cerr << "Invalid shape / stride for dimension 0. Contiguous dimension must be a multiple of "
<< packed_elements << std::endl;
exit(-1);
}
if (cute::stride<1>(operand_layout) == 1 && (cute::shape<1>(operand_layout) % packed_elements)) {
std::cerr << "Invalid shape / stride for dimension 1. Contiguous dimension must be a multiple of "
<< packed_elements << std::endl;
exit(-1);
}
const int64_t total_threads = cute::size(operand_layout) / packed_elements;
const int threads_per_block = 256;
const int64_t num_blocks = (total_threads + threads_per_block - 1) / threads_per_block;
pack_data_kernel<packed_elements><<<int(num_blocks), threads_per_block>>>(packed_data, unpacked_data, total_threads);
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
CUDA_CHECK(cudaDeviceSynchronize());
}
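For reference, the per-element effect of the kernel launched above can be written as the following sketch (illustrative only; indexing follows the [MN, K, L] and [MN, Scale_K, L] conventions noted in the comments, with g = k / group_size selecting the scale/zero group for column k):
// dq[mn, k, l] = DequantizedElement( ElementScale(q[mn, k, l]) * scale[mn, g, l]
//                                    + ElementScale(zero[mn, g, l]) )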

View File

@ -0,0 +1,520 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA
warp-specialized cooperative kernel.
The new feature showcased in this example is on-the-fly modification of TMA descriptors
to move between batches (represented by l).
To run this example:
$ ./examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm --m=2048 --n=2048 --k=2048 --l=10
*/
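A condensed sketch of the host-side setup this example performs (illustrative; see allocate(), initialize() and args_from_options() further down for the complete version): one contiguous device allocation per operand, one device pointer per batch, and arguments built in GemmUniversalMode::kArray.
//   ptr_A_host[i] = block_A.get() + i * m * k;   // i in [0, l), similarly for B, C, D
//   typename Gemm::Arguments args{
//     cutlass::gemm::GemmUniversalMode::kArray,
//     {{m, n, k, l}},
//     {ptr_A.get(), stride_A, ptr_B.get(), stride_B},
//     {{alpha, beta}, ptr_C.get(), stride_C, ptr_D.get(), stride_D},
//     hw_info};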
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
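// For these half-precision operands sizeof_bits is 16, so each of the alignment values
// above evaluates to 128 / 16 = 8 elements, i.e. one 16-byte access.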
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 10;
int m = 1024, n = 512, k = 1024, l = 10;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "56_hopper_ptr_array_batched_gemm\n\n"
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the batch count for Ptr-Array GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "56_hopper_ptr_array_batched_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
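// Worked example (illustrative): with the defaults m=1024, n=512, k=1024, l=10, this gives
// flop = 2 * 1024 * 512 * 1024 * 10, roughly 1.07e10, i.e. about 10.7 GFLOP, so an average
// runtime of 1 ms corresponds to roughly 10,700 GFLOP/s.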
};
/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t error = cudaSuccess;
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Allocates device-side data
void allocate(const Options &options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
for (int32_t i = 0; i < options.l; ++i) {
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
int64_t elements_A = options.m * options.k;
int64_t elements_B = options.k * options.n;
int64_t elements_C = options.m * options.n;
int64_t elements_D = options.m * options.n;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
//
// Assign pointers
//
std::vector<ElementA *> ptr_A_host(options.l);
std::vector<ElementB *> ptr_B_host(options.l);
std::vector<ElementC *> ptr_C_host(options.l);
std::vector<ElementC *> ptr_D_host(options.l);
for (int32_t i = 0; i < options.l; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
ptr_A.reset(options.l);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.l);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.l);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.l);
ptr_D.copy_from_host(ptr_D_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kArray,
{{options.m, options.n, options.k, options.l}},
{ptr_A.get(), stride_A, ptr_B.get(), stride_B},
{{options.alpha, options.beta}, ptr_C.get(), stride_C, ptr_D.get(), stride_D},
hw_info
};
return arguments;
}
bool verify(const Options &options) {
bool passed = true;
for (int32_t i = 0; i < options.l; ++i) {
cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({options.m, options.k}));
cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({options.k, options.n}));
cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({options.m, options.n}));
cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({options.m, options.n}));
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
{options.m, options.n, options.k},
ElementAccumulator(options.alpha),
ref_A,
ref_B,
ElementAccumulator(options.beta),
ref_C,
ref_D);
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), options.m * options.n);
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average setup and runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Batches : " << options.l << std::endl;
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits, where its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,52 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.
set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=0) # Square problem sizes
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=0) # Square problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Default problem sizes
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=0) # Default problem sizes
set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=0) # Small-k problem sizes
set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=0) # Small-k problem sizes
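# For reference, each option set above corresponds to running the built example directly with
# those flags, e.g. TEST_SQUARE is roughly equivalent to:
#   ./examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm --m=2048 --n=2048 --k=2048 -l=10 --iterations=0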
cutlass_example_add_executable(
56_hopper_ptr_array_batched_gemm
56_hopper_ptr_array_batched_gemm.cu
TEST_COMMAND_OPTIONS
TEST_SQUARE
TEST_SQUARE_LARGE_BATCH
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_BATCH
TEST_SMALLK
TEST_SMALLK_LARGE_BATCH
)

View File

@ -0,0 +1,677 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
warp-specialized cooperative kernel.
For this example all scheduling work is performed on the device.
The new feature showcased in this example is on-the-fly modification of TMA descriptors
to move between groups/problem_count (represented by groups).
To run this example:
$ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
The above example command sizes all 10 groups at the given m, n, k extents.
Skipping any of the problem dimensions randomizes it across the different groups.
To run this example for a set of problems using the benchmark option:
$ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --benchmark=./test_benchmark.txt
where the test_benchmark.txt file might look like this:
0 256x512x128
1 256x512x512
2 512x256x128
3 256x256x128
4 256x512x1024
5 1024x512x128
and so on.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
using ElementC = float; // Element type for C and D matrix operands
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedGroup; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 10;
int m = 1024, n = 2048, k = 512, groups = 10;
std::string benchmark_path;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
int const tma_alignment_bits = 128;
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
// Decide how to initialize the problems
if (!benchmark_path.empty()) {
if (!benchmark_problems()) {
problem_sizes_host.clear();
return;
}
}
else {
randomize_problems(cmd);
}
}
void randomize_problems(cutlass::CommandLine &cmd) {
int cmd_line_m = -1;
int cmd_line_n = -1;
int cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes_host.reserve(groups);
for (int i = groups; i > 0; i--) {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = ((rand() % 512) + 1);
}
if (n < 1) {
n = ((rand() % 512) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 64) + 1);
}
problem_sizes_host.push_back({m, n, k});
}
}
/// Load a benchmark
bool benchmark_problems() {
std::ifstream file(benchmark_path);
if (!file.good()) {
return false;
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
cutlass::gemm::GemmCoord extent;
std::vector<std::string> tokens;
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
}
return true;
}
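// Worked example (illustrative): ElementA here is an 8-bit FP8 type, so alignment is
// 128 / 8 = 16 elements; a benchmark line such as "0 250x512x120" would be rounded up
// to a 256x512x128 problem before being appended to problem_sizes_host.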
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "57_hopper_grouped_gemm\n\n"
<< " Hopper FP8 Grouped GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --benchmark=<str> Executes a benchmark problem size.\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "57_hopper_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
{
// Number of real-valued multiply-adds
uint64_t fmas = uint64_t();
for (auto const & problem : problem_sizes_host) {
fmas += cute::size(problem);
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t error = cudaSuccess;
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(0);
} else if (bits_input <= 8) {
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(-2);
} else {
scope_max = static_cast<Element>(8);
scope_min = static_cast<Element>(-8);
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Allocates device-side data
void allocate(const Options &options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
//
// Assign pointers
//
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{{options.alpha, options.beta}, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
return arguments;
}
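// Note the contrast with the Ptr-Array Batched GEMM example: because each group can have a
// different problem shape, grouped mode (kGrouped) passes device arrays of strides
// (stride_A.get(), ...) alongside the per-group pointer arrays, whereas the batched example
// uses a single stride shared by every batch.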
bool verify(const Options &options) {
bool passed = true;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K}));
cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N}));
cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N}));
cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N}));
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
{M, N, K},
ElementAccumulator(options.alpha),
ref_A,
ref_B,
ElementAccumulator(options.beta),
ref_C,
ref_D);
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N);
#if 0
std::cout << "Group: " << i << " Status: " << passed << std::endl;
#endif
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average setup and runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
std::cout << " Problem Sizes: " << std::endl;
for (auto const & problem : options.problem_sizes_host) {
std::cout << " " << problem << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits, where its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,56 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10) # Random problem sizes
cutlass_example_add_executable(
57_hopper_grouped_gemm
57_hopper_grouped_gemm.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_PERF_LARGE_GROUP
)

View File

@ -37,9 +37,10 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(cutlass_import_example VERSION 0.1 LANGUAGES CXX CUDA)
project(cutlass_import_example VERSION 0.2 LANGUAGES CXX CUDA)
if (DEFINED CUTLASS_DIR)
if (CUTLASS_DIR)
message(STATUS "Using CUTLASS specified at ${CUTLASS_DIR}.")
list(APPEND CMAKE_PREFIX_PATH ${CUTLASS_DIR})
endif()

View File

@ -44,7 +44,7 @@ function(cutlass_example_add_executable NAME)
set(__DISABLE_TESTS OFF)
endif()
cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS})
cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS} BATCH_SOURCES OFF)
add_dependencies(cutlass_examples ${NAME})
@ -136,6 +136,8 @@ foreach(EXAMPLE
53_hopper_gemm_permute
54_hopper_fp8_warp_specialized_gemm
55_hopper_mixed_dtype_gemm
56_hopper_ptr_array_batched_gemm
57_hopper_grouped_gemm
)
add_subdirectory(${EXAMPLE})

View File

@ -27,7 +27,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
sgemm_nt_1
sgemm_nt_1.cu
)
cutlass_example_add_executable(
tiled_copy
tiled_copy.cu
)

View File

@ -0,0 +1,260 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
// This is a simple tutorial showing several ways to partition a tensor into tiles then
// perform efficient, coalesced copies. This example also shows how to vectorize accesses
// which may be a useful optimization or required for certain workloads.
//
// `copy_kernel()` and `copy_kernel_vectorized()` each assume a pair of tensors with
// dimensions (m, n) have been partitioned via `tiled_divide()`.
//
// The results are a pair of compatible tensors with dimensions ((M, N), m', n'), where
// (M, N) denotes a statically sized tile, and m' and n' denote the number of such tiles
// within the tensor.
//
// Each statically sized tile is mapped to a CUDA threadblock which performs efficient
// loads and stores to Global Memory.
//
// `copy_kernel()` uses `cute::local_partition()` to partition the tensor and map
// the result to threads using a striped indexing scheme. Threads themselves are arranged
// in a (ThreadShape_M, ThreadShape_N) arrangement which is replicated over the tile.
//
// `copy_kernel_vectorized()` uses `cute::make_tiled_copy()` to perform a similar
// partitioning using `cute::Copy_Atom` to perform vectorization. The actual vector
// size is defined by `VecLayout`.
//
// This example assumes the overall tensor shape is divisible by the tile size and
// does not perform predication.
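// As a concrete illustration using the shapes chosen in main() below: tiled_divide() of a
// (256, 512) tensor by a (128, 64) block yields a tensor of shape ((128, 64), 2, 8), that is,
// one (128, 64) tile per CUDA threadblock, arranged in a 2 x 8 grid of tiles.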
/// Simple copy kernel.
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
{
using namespace cute;
// Slice the tiled tensors
Tensor tile_S = S(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
// Construct a partitioning of the tile among threads with the given thread arrangement.
// Concept: Tensor Layout Index
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
// Construct a register-backed Tensor with the same shape as each thread's partition
auto fragment = make_fragment_like(thr_tile_S);
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(thr_tile_S, fragment);
copy(fragment, thr_tile_D);
}
/// Vectorized copy kernel.
///
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
{
using namespace cute;
using Element = typename TensorS::value_type;
// Slice the tensors to obtain a view into each tile.
Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
// Define `AccessType` which controls the size of the actual memory access.
using AccessType = cutlass::AlignedArray<Element, size(shape(VecLayout{}))>;
// A copy atom corresponds to one hardware memory access.
using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
// Construct tiled copy, a tiling of copy atoms.
//
// Note, this assumes the vector and thread layouts are aligned with contiguous data
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
// reads. Alternative vector layouts are also possible, though incompatible layouts
// will result in compile time errors.
auto tiled_copy =
make_tiled_copy(
Atom{}, // access size
ThreadLayout{}, // thread layout
VecLayout{}); // vector layout (e.g. 4x1)
// Construct a Tensor corresponding to each thread's slice.
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tile_S);
Tensor thr_tile_D = thr_copy.partition_D(tile_D);
// Construct a register-backed Tensor with the same shape as each thread's partition
auto fragment = make_fragment_like(thr_tile_D);
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(tiled_copy, thr_tile_S, fragment);
copy(tiled_copy, fragment, thr_tile_D);
}
/// Helper to convert a shape to a dim3
template <class Shape>
dim3 shape_to_dim3(Shape shape)
{
using namespace cute;
CUTE_STATIC_ASSERT_V(rank(shape) <= Int<3>{});
auto result = append<3>(product_each(shape), 1u);
return dim3(get<0>(result), get<1>(result), get<2>(result));
}
/// Main function
int main(int argc, char** argv)
{
//
// Given a 2D shape, perform an efficient copy
//
using namespace cute;
using Element = float;
// Define a tensor shape with dynamic extents (m, n)
auto tensor_shape = make_shape(256, 512);
thrust::host_vector<Element> h_S(size(tensor_shape));
thrust::host_vector<Element> h_D(size(tensor_shape));
//
// Initialize
//
for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
}
thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;
//
// Make tensors
//
Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));
//
// Partition
//
// Define a statically sized block (M, N).
//
// Note, by convention, capital letters are used to represent static modes.
auto block_shape = make_shape(Int<128>{}, Int<64>{});
if ((get<0>(tensor_shape) % get<0>(block_shape)) || (get<1>(tensor_shape) % get<1>(block_shape))) {
std::cerr << "The tensor shape must be divisible by the block shape." << std::endl;
return -1;
}
// Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
// shape, and modes (m', n') correspond to the number of tiles.
//
// These will be used to determine the CUDA kernel grid dimensions.
Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape);
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape);
// Thread arrangement
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{}));
// Vector dimensions
Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));
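// Rough arithmetic for these layouts (illustrative): 32 x 8 = 256 threads each move a
// 4 x 1 vector of floats (16 bytes per access), so one pass of the tiled copy covers a
// (32*4) x (8*1) = 128 x 8 subtile, and each thread issues 64 / 8 = 8 such accesses to
// cover the (128, 64) block.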
//
// Determine grid and block dimensions
//
dim3 gridDim = shape_to_dim3(select<1,2>(shape(tiled_tensor_D))); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(shape(thr_layout)));
//
// Launch the kernel
//
copy_kernel_vectorized<<< gridDim, blockDim >>>(
tiled_tensor_S,
tiled_tensor_D,
thr_layout,
vec_layout);
cudaError result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
return -1;
}
//
// Verify
//
h_D = d_D;
int32_t errors = 0;
int32_t const kErrorLimit = 10;
for (size_t i = 0; i < h_D.size(); ++i) {
if (h_S[i] != h_D[i]) {
std::cerr << "Error. S[" << i << "]: " << h_S[i] << ", D[" << i << "]: " << h_D[i] << std::endl;
if (++errors >= kErrorLimit) {
std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl;
return -1;
}
}
}
std::cout << "Success." << std::endl;
return 0;
}

View File

@ -357,7 +357,7 @@
"## Handling errors\n",
"The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n",
"\n",
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user."
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user. Uncomment and run the code below to see this error."
]
},
{
@ -371,6 +371,75 @@
"# td.stages = 8\n",
"# plan.compile(td)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Specializations for other data types\n",
"\n",
"Various CUTLASS kernels specialized for specific data types can also be run via the Python interface.\n",
"\n",
"For example, the code below shows how to declare and run a GEMM using the 3xTF32 feature (see corresponding C++ example [here](https://github.com/NVIDIA/cutlass/blob/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu))."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cutlass.backend.utils.device import device_cc\n",
"\n",
"# 3xTF32 requires SM80 or higher\n",
"if device_cc() >= 80:\n",
" plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)\n",
" plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32\n",
"\n",
" # Create input/output tensors in FP32\n",
" A, B = [np.ones((128, 128)).astype(np.float32) for _ in range(2)]\n",
" C, D = [np.zeros((128, 128)).astype(np.float32) for _ in range(2)]\n",
"\n",
" # Run the GEMM\n",
" plan.run(A, B, C, D, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Additionally, one can run CUTLASS's FP8 GEMMs if using a frontend library capable of allocating and initializing FP8 tensors (e.g., PyTorch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import torch\n",
"except ImportError:\n",
" print(\"PyTorch is not available. Skipping FP8 example\")\n",
" import sys; sys.exit(0)\n",
"\n",
"if not hasattr(torch, \"float8_e4m3fn\"):\n",
" print(\"Version of PyTorch does not have the float8_e4m3fn data type. Skipping FP8 example\")\n",
" import sys; sys.exit(0)\n",
"\n",
"# FP8 is supported through the CUTLASS Python interface on SM90 and higher\n",
"if device_cc() >= 90:\n",
" plan = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n",
" layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,\n",
" layout_C=cutlass.LayoutType.ColumnMajor)\n",
"\n",
" # Create input/output tensors in FP8\n",
" A, B = [torch.ones((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n",
" C, D = [torch.zeros((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n",
"\n",
" # Run the GEMM\n",
" plan.run(A, B, C, D, print_module=print_module)"
]
}
],
"metadata": {

View File

@ -134,7 +134,7 @@
"id": "590a3bc5",
"metadata": {},
"source": [
"We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch."
"We'll next run a group of 20 GEMMs via the CUTLASS Python interface and via PyTorch."
]
},
{
@ -144,7 +144,7 @@
"metadata": {},
"outputs": [],
"source": [
"As, Bs, Cs, Ds, = generate_problems(50)\n",
"As, Bs, Cs, Ds, = generate_problems(20)\n",
"\n",
"plan.run(As, Bs, Cs, Ds, print_module=True)\n",
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",

View File

@ -95,6 +95,7 @@ CUTE_DEVICE dim3 cluster_grid_dims()
return gridDim;
#elif defined(_MSC_VER)
CUTE_RUNTIME_ASSERT("cluster_grid_dims() can only be called on device");
return {0, 0, 0};
#else
return {0, 0, 0};
#endif
@ -114,6 +115,7 @@ CUTE_DEVICE dim3 cluster_id_in_grid()
return blockIdx;
#elif defined(_MSC_VER)
CUTE_RUNTIME_ASSERT("cluster_id_in_grid() can only be called on device");
return {0, 0, 0};
#else
return {0, 0, 0};
#endif

View File

@ -40,6 +40,11 @@
# define CUTE_ARCH_TMA_SM90_ENABLED
#endif
#if defined(CUTE_ARCH_TMA_SM90_ENABLED) && \
((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3)))
# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED
#endif
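// i.e., device-side modification of TMA descriptors requires both SM90 TMA support and
// a CUDA toolkit of 12.3 or newer.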
namespace cute
{

View File

@ -42,6 +42,7 @@
#include <cute/container/alignment.hpp>
#include <cute/container/bit_field.hpp>
#include <cute/container/array.hpp>
#include <cute/numeric/int.hpp> // to_Format<[u]intX>
#include <cute/numeric/half.hpp> // to_Format<half_t>
@ -200,6 +201,141 @@ prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Perform a TensorMap modification (by each field)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Replace tensor pointer directly in GMEM
CUTE_HOST_DEVICE
void
tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr,
void const* const new_tensor_ptr)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint64_t const new_desc_addr = reinterpret_cast<uint64_t>(new_tensor_ptr);
asm volatile (
"tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;"
:: "l"(gmem_int_desc), "l"(new_desc_addr));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
// Replace tensor pointer in a tensormap that has been brought from GMEM into shared memory
CUTE_HOST_DEVICE
void
tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
void const* const new_tensor_ptr)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
uint64_t const new_desc_addr = reinterpret_cast<uint64_t>(new_tensor_ptr);
uint64_t const smem_int64_desc = 0;
asm volatile (
"cvt.u64.u32 %0, %1;"
:: "l"(smem_int64_desc), "r"(smem_int_desc));
asm volatile (
"tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
:: "l"(smem_int64_desc), "l"(new_desc_addr));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
// Replace tensor dims and strides (for GEMMs) in a tensormap that has been brought from GMEM into shared memory
CUTE_HOST_DEVICE
void
tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc,
cute::array<uint32_t, 3> const& prob_shape,
cute::array<uint64_t, 3> const& prob_stride)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
uint64_t const smem_int64_desc = 0;
asm volatile (
"cvt.u64.u32 %0, %1;"
:: "l"(smem_int64_desc), "r"(smem_int_desc));
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[0]));
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[1]));
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[2]));
// Strides must be a multiple of 16. Also, the stride for the innermost dimension is implicitly 1.
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4));
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory)
////////////////////////////////////////////////////////////////////////////////////////////////////
CUTE_HOST_DEVICE
void
tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
asm volatile (
"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"
:: "l"(gmem_int_desc), "r"(smem_int_desc));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Perform a release fence operation (needed when modifying tensormap directly in GMEM)
////////////////////////////////////////////////////////////////////////////////////////////////////
CUTE_HOST_DEVICE
void
tma_descriptor_fence_release()
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
asm volatile ("fence.proxy.tensormap::generic.release.gpu;");
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Perform an acquire fence operation
////////////////////////////////////////////////////////////////////////////////////////////////////
CUTE_HOST_DEVICE
void
tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
asm volatile (
"fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
:
: "l"(gmem_int_desc)
: "memory");
asm volatile (
"cvta.global.u64 %0, %0;"
:
: "l"(gmem_int_desc), "l"(gmem_int_desc)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
///////////////////////////////////////////////////////////////////////////////
} // end namespace cute
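For orientation, a minimal usage sketch of the descriptor-modification primitives above, for the GMEM-resident case (illustrative only; the kernel, descriptor pointer, and single-thread guard are assumptions, not code from this commit):
__device__ void update_tma_global_address(cute::TmaDescriptor* gmem_desc, void const* new_ptr)
{
  if (threadIdx.x == 0) {
    // Patch the global address field of the descriptor that lives in GMEM.
    cute::tma_descriptor_replace_addr_in_global_mem(gmem_desc, new_ptr);
    // Release fence so the tensormap proxy observes the modification.
    cute::tma_descriptor_fence_release();
  }
  __syncthreads();
  // Each thread that subsequently issues TMA with this descriptor acquires it first.
  cute::tma_descriptor_fence_acquire(gmem_desc);
}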

View File

@ -775,6 +775,104 @@ struct SM90_TMA_STORE
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_STORE im2col: Initiates a TMA copy, in im2col mode, from shared memory to global memory
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_STORE_IM2COL_3D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group"
" [%0, {%2, %3, %4}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(coord_c), "r"(coord_w), "r"(coord_n)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE_IM2COL_4D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group"
" [%0, {%2, %3, %4, %5}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE_IM2COL_5D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE_IM2COL
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n)
{
return SM90_TMA_STORE_IM2COL_3D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_n);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n)
{
return SM90_TMA_STORE_IM2COL_4D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_n);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n)
{
return SM90_TMA_STORE_IM2COL_5D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n);
}
};
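// Note: SM90_TMA_STORE_IM2COL::copy dispatches to the 3D/4D/5D variants above purely on the
// number of coordinates supplied ((c,w,n), (c,w,h,n), or (c,w,h,d,n)).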
// Indicate arrival of warp issuing TMA_STORE
CUTE_HOST_DEVICE static void
tma_store_arrive() {

View File

@ -1,4 +1,4 @@
/**************************************************************************************************
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*

View File

@ -31,26 +31,28 @@
#pragma once
#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
#include <cute/tensor.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/tensor.hpp>
namespace cute
{
template <class... Args>
struct Copy_Atom;
template <class CopyOperation, class T>
struct Copy_Atom<CopyOperation, T> : Copy_Atom<Copy_Traits<CopyOperation>, T>
template <class CopyOperation, class CopyInternalType>
struct Copy_Atom<CopyOperation, CopyInternalType> : Copy_Atom<Copy_Traits<CopyOperation>, CopyInternalType>
{};
template <class... Args, class T>
struct Copy_Atom<Copy_Traits<Args...>, T>
template <class... Args, class CopyInternalType>
struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
: Copy_Traits<Args...>
{
using Traits = Copy_Traits<Args...>;
@ -61,7 +63,7 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
using BitLayoutDst = typename Traits::DstLayout;
using BitLayoutRef = typename Traits::RefLayout;
using ValType = T;
using ValType = CopyInternalType;
using ValLayoutSrc = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutSrc{}));
using ValLayoutDst = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutDst{}));
@ -80,7 +82,7 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
auto
with(TraitsArgs&&... args) const {
auto traits = Traits::with(std::forward<TraitsArgs>(args)...);
return Copy_Atom<decltype(traits), T>{traits};
return Copy_Atom<decltype(traits), CopyInternalType>{traits};
}
//
@ -88,19 +90,19 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
//
// Check and call instruction, or recurse
template <class TS, class SLayout,
class TD, class DLayout>
template <class SEngine, class SLayout,
class DEngine, class DLayout>
CUTE_HOST_DEVICE
void
call(Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) const
call(Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> & dst) const
{
static_assert(SLayout::rank == 1, "Expected rank-1 src tensor");
static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor");
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value ||
is_constant<NumValDst, decltype(size(dst))>::value) {
// Dispatch to unpack for instruction
// Dispatch to unpack to execute the instruction
return copy_unpack(*this, src, dst);
} else
if constexpr (is_tuple<decltype(shape(src))>::value &&
@ -110,7 +112,7 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
// ((A,B,C,...)) -> (A,B,C,...)
return copy(*this, tensor<0>(src), tensor<0>(dst));
} else {
static_assert(sizeof(TS) < 0, "No instruction match and no recursion possible.");
static_assert(dependent_false<SEngine>, "No instruction match and no recursion possible.");
}
}
@ -135,7 +137,7 @@ struct ThrCopy;
template <class Copy_Atom,
class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
class ShapeTile_MN> // coord space
class ShapeTiler_MN> // coord space
struct TiledCopy : Copy_Atom
{
// Layout information from the CopyAtom
@ -148,8 +150,7 @@ struct TiledCopy : Copy_Atom
using AtomNumVal = decltype(size<1>(AtomLayoutRef{}));
// Layout information for the TiledCopy
using Tiler_MN = ShapeTile_MN;
using TiledShape_MN = decltype(shape(ShapeTile_MN{}));
using Tiler_MN = ShapeTiler_MN;
using TiledLayout_TV = LayoutCopy_TV;
using TiledNumThr = decltype(size<0>(TiledLayout_TV{}));
using TiledNumVal = decltype(size<1>(TiledLayout_TV{}));
@ -172,12 +173,9 @@ struct TiledCopy : Copy_Atom
auto
tidfrg_S(STensor&& stensor)
{
constexpr int R = remove_cvref_t<STensor>::rank;
static_assert(R >= rank_v<TiledShape_MN>, "Rank of tensor to be partitioned too small.");
// Generalize the dimension checks for arbitrary rank
//CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
//CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small.");
// Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout
return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}));
}
@ -196,17 +194,14 @@ struct TiledCopy : Copy_Atom
auto
tidfrg_D(DTensor&& dtensor)
{
constexpr int R = remove_cvref_t<DTensor>::rank;
static_assert(R >= rank_v<TiledShape_MN>, "Rank of tensor to be partitioned too small.");
// Generalize the dimension checks for arbitrary rank
//CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
//CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small.");
// Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout
return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}));
}
// Tile a tensor or a layout from shape
// (Tile,(RestM,RestN,...))
// ((TileM,TileN,...), (RestM,RestN,...))
// to shape
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
template <class Tensor, class Ref2TrgLayout>
@ -232,7 +227,7 @@ struct TiledCopy : Copy_Atom
// Transform the tile mode
auto tv_tensor = tensor.compose(thrval2mn, _);
// ((thrid,val),(RM,RN,...))
// ((thrid,val),(RestM,RestN,...))
// Unfold and return
return tv_tensor(make_coord(_,_), _);
@ -253,7 +248,7 @@ struct TiledCopy : Copy_Atom
auto V = size<0>(tensor);
auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{}));
auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{})));
// (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV
auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{}));
@ -278,7 +273,7 @@ struct TiledCopy : Copy_Atom
get_layoutS_TV()
{
// (M,N) -> (M,N)
auto ref_S = make_layout(make_shape(TiledShape_MN{}, Int<1>{}));
auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{}));
// (thr_idx,val_idx) -> (M,N)
return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{});
}
@ -290,7 +285,7 @@ struct TiledCopy : Copy_Atom
// (thr_idx,val_idx) -> (M,N)
auto layoutS_TV = get_layoutS_TV();
// (M,K) -> (thr_idx,val_idx)
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(TiledShape_MN{});
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{}));
// athrid = (v,m,k) -> thr_idx
auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));
@ -303,7 +298,7 @@ struct TiledCopy : Copy_Atom
get_layoutD_TV()
{
// (M,N) -> (M,N)
auto ref_D = make_layout(make_shape(TiledShape_MN{}, Int<1>{}));
auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{}));
// (thr_idx,val_idx) -> (M,N)
return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{});
}
@ -315,7 +310,7 @@ struct TiledCopy : Copy_Atom
// (thr_idx,val_idx) -> (M,N)
auto layoutD_TV = get_layoutD_TV();
// (M,K) -> (thr_idx,val_idx)
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(TiledShape_MN{});
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{}));
// athrid = (v,m,k) -> thr_idx
auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));
@ -406,51 +401,44 @@ make_tiled_copy_impl(Copy_Atom<Args...> const& atom,
// These tile the Copy_Atom as a whole
//
template <class... Args,
class TiledMMA>
template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma)
make_tiled_copy_A(Copy_Atom<CArgs...> const& copy_atom,
TiledMMA<MArgs...> const& mma)
{
using MNK = typename TiledMMA::TiledShape_MNK;
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{})));
return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma)));
}
template <class... Args,
class TiledMMA>
template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_B(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma)
make_tiled_copy_B(Copy_Atom<CArgs...> const& copy_atom,
TiledMMA<MArgs...> const& mma)
{
using MNK = typename TiledMMA::TiledShape_MNK;
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{})));
return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma)));
}
template <class... Args,
class TiledMMA>
template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma)
make_tiled_copy_C(Copy_Atom<CArgs...> const& copy_atom,
TiledMMA<MArgs...> const& mma)
{
using MNK = typename TiledMMA::TiledShape_MNK;
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{})));
return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma)));
}
// returns the smallest tiled copy that can retile LayoutC_TV
// for use with pipelined epilogues with subtiled stores
template <class... Args,
class TiledMMA>
template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_atom(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma)
make_tiled_copy_C_atom(Copy_Atom<CArgs...> const& copy_atom,
TiledMMA<MArgs...> const& mma)
{
// Truncate the V-layout to just the Copy_Atom, keep the V-order
auto layoutC_TV = tiled_mma.get_layoutC_TV();
auto copy_V = Int<Copy_Atom<Args...>::NumValSrc>{};
auto layoutC_TV = mma.get_layoutC_TV();
auto copy_V = Int<Copy_Atom<CArgs...>::NumValSrc>{};
CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV));
auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V)));
@ -458,8 +446,7 @@ make_tiled_copy_C_atom(Copy_Atom<Args...> const& copy_atom,
// Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them
// Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA
using MNK = typename TiledMMA::TiledShape_MNK;
auto mma_tiler = make_shape(size<0>(MNK{}),size<1>(MNK{}));
auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma));
auto mma_zeros = repeat_like(mma_tiler, Int<0>{});
auto tiler = transform(make_seq<rank(mma_tiler)>{}, [&](auto i) {
@ -474,8 +461,6 @@ make_tiled_copy_C_atom(Copy_Atom<Args...> const& copy_atom,
// (tid,vid) -> tile_coord
auto layout_tv = composition(left_inverse(tile2mma), layout_TV);
using MNK = typename TiledMMA::TiledShape_MNK;
return make_tiled_copy_impl(copy_atom, layout_tv, tiler);
}
@ -655,8 +640,10 @@ print(TiledCopy<Atom, Args...> const& copy, char const* pad = "")
template <class TiledCopy, class ThrIdx>
CUTE_HOST_DEVICE
void
print(ThrCopy<TiledCopy, ThrIdx> const&)
print(ThrCopy<TiledCopy, ThrIdx> const& thr_copy)
{
print("ThrCopy\n");
print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n");
print(TiledCopy{});
}

View File

@ -43,10 +43,14 @@
namespace cute
{
template <class GmemStrides_, class TmaGBasis_, class TmaSwizzle_>
template <class GmemTmaBasisStrides_, class TmaGmemBasis_, class TmaSwizzle_>
struct AuxTmaParams {
using GmemStrides = GmemStrides_;
using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic
GmemStrides g_stride_;
using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static
static_assert(is_static<TmaGmemBasis>::value);
using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle<B,M,S>
static_assert(is_static<TmaSwizzle>::value);
};
//////////////////////////////////////////////////////////////////////////////
@ -138,13 +142,19 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
// Construct an executable SM90_TMA_LOAD with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const {
with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
// We accept multicast_mask here to keep the API for both atoms consistent
// assert(multicast_mask == 0);
(void) multicast_mask;
return {tma_desc_, tma_mbar};
}
// Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {*new_tma_desc, tma_mbar};
}
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
@ -251,6 +261,13 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
return {tma_desc_, tma_load_mbar, multicast_mask};
}
// Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
return {*new_tma_desc, tma_load_mbar, multicast_mask};
}
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
@ -508,51 +525,91 @@ coalesce_256(Layout<Shape,Stride> const& layout)
return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride));
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
coalesce_256(Tensor<Engine,Layout> const& tensor)
{
return make_tensor(tensor.data(), coalesce_256(tensor.layout()));
}
// Use a smem_inv_h to read through the GMEM tensor
// and construct a TMA Descriptor for the resulting instruction
// At the same time, construct the Tma Tensor's Stride to generate
// the TMA coordinates that the instruction consumes.
//
template <class TmaInternalType,
class GEngine, class GLayout,
class SShape, class SStride,
int B, int M, int S>
CUTE_HOST_RTC
class VShape, class VStride>
CUTE_HOST_DEVICE constexpr
auto
make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GMEM Tensor
Layout<SShape,SStride> const& smem_inv_h, // smem_idx to hier gmode
Swizzle<B,M,S> const& swizzle) // Swizzle fn on smem_idx
construct_tma_gbasis(Tensor<GEngine,GLayout> const& gtensor, // The original GMEM Tensor
Layout<SShape,SStride> const& slayout, // The layout of SMEM
Layout<VShape,VStride> const& cta_v_map) // smem_idx to hier gmode
{
//
// TMA parameter checking
//
CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)),
"TMA requires CTA_Tile and SLayout top-level shape equivalence.");
#if 0
print("gtensor : "); print(gtensor); print("\n");
print("slayout : "); print(slayout); print("\n");
print("cta_v_map : "); print(cta_v_map); print("\n");
#endif
//
// TMA slayout manipulation
//
// Invert the smem to get the largest contiguous vector in the smem layout
// smem idx -> smem coord
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
// Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode
// smem idx -> gmem mode
auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout));
#if 0
print("inv_smem_layout : "); print(inv_smem_layout); print("\n");
print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n");
#endif
//
// TMA gtensor truncation
//
// Truncate any incompatibilities -- no starting in the middle of gmodes
auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) {
[[maybe_unused]] auto v = basis_value(e);
return not is_constant<1,decltype(v)>{};
});
static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?");
// Keep only the static-1 basis modes into gmem
auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full);
#if 0
print("smem_rank : "); print(smem_rank); print("\n");
print("sidx2gmode : "); print(sidx2gmode); print("\n");
#endif
//
// TMA gtensor manipulation
//
// The smem vector is the same units as gtensor, so compose first and then recast
// tma_val_idx:gmem_strides
Tensor tile_gstride = recast<TmaInternalType>(gtensor.compose(smem_inv_h));
auto tile_gstride = recast<TmaInternalType>(gtensor.compose(sidx2gmode)).layout();
// Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType)
// tma_box_shape:gmem_strides
Tensor tma_gstride = coalesce_256(tile_gstride);
auto tma_gstride = coalesce_256(tile_gstride);
// Perform the tiling to the gmem vector again, but with indirections to the gtensor modes
// Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes
auto gbasis = make_identity_layout(shape(gtensor));
auto tile_gbasis_tmp = gbasis.compose(smem_inv_h);
auto tile_gbasis_tmp = gbasis.compose(sidx2gmode);
// Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape
// tma_box_shape:gmem_mode
auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp));
// Recast the original tensor for shape inspections
auto gtensor_T = recast<TmaInternalType>(gtensor);
// "Coalesce" the tile basis into a compatible shape with the tma_gstride
auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride))));
// Recast the original tensor for shape/stride inspections
Tensor gtensor_T = recast<TmaInternalType>(gtensor);
// Find missing bases that don't appear in tile_gbasis
// NOTE This is essentially ArithmeticTuple complement...
// NOTE in pursuit of implementing an ArithmeticTuple logical_divide for smem_inv_h
auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)),
flatten(stride(gbasis)),
[&](auto s, auto d, auto e)
@ -561,7 +618,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
return cute::tuple<>{}; // If size-1 or stride-0, then don't append
} else {
using E = decltype(e);
auto has_e = any_of(stride(tile_gbasis), [] (auto tb) { return tb == E{}; });
auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; });
if constexpr (decltype(has_e)::value) {
return cute::tuple<>{}; // If d was found, then don't append
} else {
@ -569,13 +626,10 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
}
}
});
auto tile_gbasis_remaining_rank = rank(tile_gbasis_remaining_stride);
// "Coalesce" the tile basis into a compatible shape with the tma
auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride))));
// Append the remaining basis modes that contribute to the TMA with size-1
auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(repeat<tile_gbasis_remaining_rank>(Int<1>{}))),
auto tile_gbasis_remaining_shape = repeat<rank(tile_gbasis_remaining_stride)>(Int<1>{});
auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )),
tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride)));
// Group the trailing modes to make this max rank-5 -- TMA rank limitation
@ -583,15 +637,98 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
auto tma_gbasis = group<cute::min(rank(tma_gbasis_full),4),-1>(tma_gbasis_full);
#if 0
print("smem_inv_h : "); print(smem_inv_h); print("\n");
print("gtensor : "); print(gtensor); print("\n");
print("tile_gstride : "); print(tile_gstride); print("\n");
print("tma_gstride : "); print(tma_gstride); print("\n");
print("gbasis : "); print(gbasis); print("\n");
print("tile_gbasis : "); print(tile_gbasis); print("\n");
print("tile_gbasis : "); print(tma_gbasis_tile); print("\n");
print("tma_gbasis : "); print(tma_gbasis); print("\n");
#endif
return tma_gbasis;
}
template <class GEngine, class GLayout,
class TmaGmemBasisStride,
class ShapeT, size_t TmaRank>
CUTE_HOST_DEVICE constexpr
void
fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType
TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s)
cute::array<ShapeT, TmaRank> & gmem_prob_shape, // Tma Shapes, uint32_t or uint64_t
cute::array<uint64_t, TmaRank> & gmem_prob_stride) // Tma Strides
{
static_assert(is_tuple<TmaGmemBasisStride>::value);
static_assert(is_same<uint32_t, ShapeT>::value || is_same<uint64_t, ShapeT>::value);
using TmaInternalType = typename GEngine::value_type;
constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value;
static_assert(TmaRank >= tma_rank);
auto gmem_shape = shape(gtensor);
auto gmem_stride = stride(gtensor);
// Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides
for_each(make_seq<tma_rank>{}, [&](auto i) {
constexpr int tma_i_rank = decltype(rank<i>(tma_gbasis_stride))::value;
if constexpr (tma_i_rank == 1) {
// Trivial contribution of this gmem mode to this tma mode
auto ej = unwrap(get<i>(tma_gbasis_stride));
gmem_prob_shape[i] = basis_get(ej, gmem_shape);
gmem_prob_stride[i] = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
} else {
// Apply a recurrence to each gmem mode that contributes to this tma mode
for_each(get<i>(tma_gbasis_stride), [&](auto ej) {
// Problem shape
uint64_t shape_j = basis_get(ej, gmem_shape);
// Problem stride (in bytes)
uint64_t stride_j = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
uint64_t old_stride = gmem_prob_stride[i];
gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j);
if (gmem_prob_stride[i] != 0) {
// Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1
gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i])
+ (shape_j-1) * (stride_j / gmem_prob_stride[i])
+ 1;
} else {
gmem_prob_shape[i] = shape_j;
}
});
}
});
}
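// Worked example (illustrative): merging two gmem modes of shapes (4, 8) and byte strides
// (16, 64) into one TMA mode. After the first mode: stride = 16, shape = 4. After the second:
// gcd(16, 64) = 16 and shape = (4-1)*(16/16) + (8-1)*(64/16) + 1 = 3 + 28 + 1 = 32, i.e. the
// two nested modes collapse into a single extent-32, stride-16B TMA mode because 4*16 == 64
// makes them contiguous.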
// Overload for an existing Copy_Traits
template <class GEngine, class GLayout,
class Op, class Bits, class Aux,
class ShapeT, size_t TmaRank>
CUTE_HOST_DEVICE constexpr
void
fill_tma_gmem_shape_stride(Copy_Traits<Op,Bits,Aux> const& tma_traits,
Tensor<GEngine,GLayout> const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType
cute::array<ShapeT, TmaRank> & gmem_prob_shape, // Tma Shapes, uint32_t or uint64_t
cute::array<uint64_t, TmaRank> & gmem_prob_stride) // Tma Strides
{
return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}),
gmem_prob_shape, gmem_prob_stride);
}
// Use a sidx2gmode to read through the GMEM tensor
// and construct a TMA Descriptor for the resulting instruction
// At the same time, construct the Tma Tensor's Stride to generate
// the TMA coordinates that the instruction consumes.
//
template <class TmaInternalType,
class GEngine, class GLayout,
class TShape, class TStride,
int B, int M, int S>
CUTE_HOST_RTC
auto
make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GMEM Tensor
Layout<TShape,TStride> const& tma_gbasis, // TMA mode -> GMEM mode mapping
Swizzle<B,M,S> const& swizzle, // Swizzle fn on smem_idx
uint32_t num_multicast) // The number of CTAs in multicasting
{
//
// TMA desc creation
//
@ -602,31 +739,16 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
// TMA gmem desc info
//
// Recast the original tensor for shape/stride inspections
Tensor gtensor_T = recast<TmaInternalType>(gtensor);
void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data());
auto gmem_layout = gtensor_T.layout();
cute::array<uint64_t, 5> gmem_prob_shape = {1,1,1,1,1};
cute::array<uint64_t, 5> gmem_prob_stride = {0,0,0,0,0};
// Use the indirections in tma_gbasis in the values of flat_glayout to construct the gmem shapes/strides
for_each(make_seq<tma_dim>{}, [&](auto i) {
for_each(stride<i>(tma_gbasis), [&](auto ej) {
// Problem stride
uint64_t stride_j = ceil_div(basis_get(ej, stride(gmem_layout)) * sizeof_bits_v<TmaInternalType>, 8);
uint64_t old_stride = gmem_prob_stride[i];
gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j);
// Problem shape
uint64_t shape_j = basis_get(ej, shape(gmem_layout));
if (gmem_prob_stride[i] != 0) {
// Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1
gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i])
+ (shape_j-1) * (stride_j / gmem_prob_stride[i])
+ 1;
} else {
gmem_prob_shape[i] = shape_j;
}
});
});
fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride);
assert((reinterpret_cast<uint64_t>(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned
@ -663,6 +785,13 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
for_each(make_seq<tma_dim>{}, [&](auto i) {
smem_box_shape[i] *= size<i>(tma_gbasis);
});
// Finally, truncate the tma box by the num_multicast
for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) {
assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0);
uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]);
smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast);
multicast = new_mult;
}
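// e.g., with an outermost box extent of 64 and num_multicast == 2, each participating CTA is
// assigned a 32-element slice of the box (illustrative numbers).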
assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1
assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256
@ -740,26 +869,27 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
auto recast_ratio = cute::ratio(Int<sizeof_bits<typename GEngine::value_type>::value>{},
Int<sizeof_bits< TmaInternalType>::value>{});
auto gbasis = make_basis_like(shape(gtensor));
// Finally, get the inverse permutation of the E<i> bases for the mocked gmem stride
// NOTE This is essentially ArithmeticTuple inverse...
auto gmem_stride_bases = transform_leaf(stride(gbasis), [&](auto ei) {
auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) {
auto si = basis_get(ei, shape(gmem_layout));
auto di = basis_get(ei, stride(gmem_layout));
if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) {
return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA
} else {
auto tma_gbasis_stride = stride(tma_gbasis);
auto tma_gmem_basis_stride = stride(tma_gbasis);
// Find j such that E<i> is in stride<j>(tma_gbasis)
using EI = decltype(ei);
[[maybe_unused]] auto j = find_if(tma_gbasis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); });
if constexpr (decltype(j == rank(tma_gbasis_stride))::value) {
[[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); });
if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) {
return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA
} else
if constexpr (decltype(j == Int<0>{})::value) {
auto scale = recast_ratio * basis_get(ei, stride(gtensor));
return E<j>{} * scale; // Return TMA Coord basis -- with a recast scale factor
} else
if constexpr (decltype(rank<j>(tma_gbasis_stride) == Int<1>{})::value) {
if constexpr (decltype(rank<j>(tma_gmem_basis_stride) == Int<1>{})::value) {
return E<j>{}; // Return TMA Coord basis -- known scale of Int<1>{}
} else {
int32_t scale = ceil_div(int32_t(di * sizeof_bits_v<TmaInternalType> / cute::max(gmem_prob_stride[j], 16)), 8);
@ -768,14 +898,64 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
}
});
#if 0
print("tma_gbasis : "); print(gmem_stride_bases); print("\n");
#endif
#if 0
print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n");
#endif
using AuxParams = AuxTmaParams<decltype(gmem_stride_bases),
using AuxParams = AuxTmaParams<decltype(gmem_tma_basis_stride),
decltype(tma_gbasis),
decltype(swizzle)>;
return cute::make_tuple(tma_desc, AuxParams{gmem_stride_bases});
return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride});
}
template <class TmaInternalType,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class VShape, class VStride>
CUTE_HOST_RTC
auto
make_tma_copy_atom(CopyOp,
Tensor<GEngine,GLayout> const& gtensor, // Full GMEM Tensor
SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled
uint32_t const& num_multicast, // The number of CTAs involved in multicasting
Layout<VShape,VStride> const& cta_v_map) // V: CTA val idx -> gmem mode
{
//
// TMA truncated layout
//
auto smem_swizzle = get_swizzle_portion(slayout);
auto smem_layout = get_nonswizzle_portion(slayout);
auto tma_gbasis = detail::construct_tma_gbasis<TmaInternalType>(gtensor, smem_layout, cta_v_map);
//
// Construct the TMA Desc and the strides of the TMA Tensor
//
auto [tma_desc, aux_params] = detail::make_tma_copy_desc<TmaInternalType>(gtensor,
tma_gbasis,
smem_swizzle,
num_multicast);
//
// Construct the Copy_Traits
//
constexpr int num_bits_per_tma = decltype(size(tma_gbasis))::value * sizeof_bits_v<TmaInternalType>;
using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(aux_params)>;
using Atom = Copy_Atom<Traits, typename GEngine::value_type>;
Traits tma_traits{tma_desc, aux_params};
#if 0
print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n");
print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n");
#endif
// Return the Copy_Atom
return Atom{tma_traits};
}
// The "logical TMA tid" is a map from the CTA rank to its logical id
@ -790,122 +970,46 @@ template <class TmaInternalType,
class VShape, class VStride>
CUTE_HOST_RTC
auto
make_tma_copy_tiled(CopyOp,
make_tma_copy_tiled(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor, // Full GMEM Tensor
SLayout const& slayout, // CTA Tile of SMEM
Layout<TShape,TStride> const& cta_t_map, // T: CTA thr idx -> logical TMA tid
Layout<VShape,VStride> const& cta_v_map) // V: CTA val idx -> gmem mode
{
//
// TMA parameter checking
//
CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)),
"TMA requires CTA_Tile and SLayout top-level shape equivalence.");
CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{},
"Number of active CTAs in TMA must divide domain size of slayout.");
//
// TMA slayout manipulation
//
// Invert the smem to get the largest contiguous vector in the smem layout
// smem idx -> smem coord
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
// Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode
// smem idx -> gmem mode
auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout));
#if 0
print("g_tensor : "); print(gtensor); print("\n");
print("s_layout : "); print(slayout); print("\n");
print("cta_t_map : "); print(cta_t_map); print("\n");
print("cta_v_map : "); print(cta_v_map); print("\n");
print("inv_s_layout : "); print(inv_smem_layout); print("\n");
print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n");
#endif
//
// TMA gtensor manipulation
//
// Generate a TupleBasis for the gtensor
// gmem coord -> gmem coord
auto glayout_basis = make_identity_layout(shape(gtensor));
// Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc
// smem idx -> gmem coord
auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode));
// Truncate any incompatibilities -- no starting in the middle of gmodes
auto smem_rank = find_if(stride(tma_layout_full), [](auto e) {
[[maybe_unused]] auto v = basis_value(e);
return not is_constant<1,decltype(v)>{};
});
static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?");
// Keep only the static-1 basis modes into gmem
auto tma_layout_trunc = take<0,smem_rank>(tma_layout_full);
// Keep only the portion each multicast CTA will be responsible for
auto tma_layout_v = composition(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map)));
#if 0
print("glayout_basis : "); print(glayout_basis); print("\n");
print("tma_layout_full : "); print(tma_layout_full); print("\n");
print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n");
print("tma_layout_v : "); print(tma_layout_v); print("\n");
#endif
//
// Construct the TMA Desc and the strides of the TMA Tensor
//
auto [tma_desc, aux_params] = detail::make_tma_copy_desc<TmaInternalType>(gtensor,
tma_layout_v,
get_swizzle_portion(slayout));
//
// Construct the Copy_Traits
//
using T = typename GEngine::value_type;
constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof_bits_v<T>;
using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(aux_params)>;
using Atom = Copy_Atom<Traits, T>;
Traits tma_traits{tma_desc, aux_params};
#if 0
print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n");
print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n");
#endif
Copy_Atom atom = make_tma_copy_atom<TmaInternalType>(copy_op, gtensor, slayout,
cosize(cta_t_map), cta_v_map);
//
// Construct the TiledCopy
//
auto cta_tiler = product_each(shape(cta_v_map));
[[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map));
auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / Int<sizeof_bits_v<typename GEngine::value_type>>{};
// smem idx -> smem coord
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
// CTA V -> smem_coord
auto layout_v = composition(inv_smem_layout, size(tma_layout_trunc));
auto layout_v = composition(inv_smem_layout, num_elems_per_tma);
// Scale that up to cover all of the smem_coords
auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map));
// CTA T -> smem idx
auto layout_t = make_layout(cosize(cta_t_map), shape_div(size(tma_layout_trunc), cosize(cta_t_map)));
auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map)));
// CTA TID -> smem coord
auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map));
// Combine with the T mapping
auto layout_TV = make_layout(layout_T, layout_V);
[[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V);
#if 0
print("cta_tiler : "); print(cta_tiler); print("\n");
print("layout_VT : "); print(layout_VT); print("\n");
print("layout_v : "); print(layout_v); print("\n");
print("layout_V : "); print(layout_V); print("\n");
print("layout_t : "); print(layout_t); print("\n");
print("layout_T : "); print(layout_T); print("\n");
print("layout_TV : "); print(layout_TV); print("\n");
#endif
return TiledCopy<Atom, decltype(layout_TV), decltype(cta_tiler)>{tma_traits};
return TiledCopy<decltype(atom), decltype(layout_TV), decltype(cta_tiler)>{atom};
}
} // end namespace detail
@ -982,7 +1086,7 @@ make_tma_copy_tiled(CopyOp,
copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params
*/
template <class TmaInternalType,
template <class TmaInternalType = void,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
@ -996,37 +1100,16 @@ make_tma_copy(CopyOp const& copy_op,
CTA_Tiler const& cta_tiler,
Cluster_Size const& cluster_size)
{
auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler);
auto cta_t_tile = make_layout(cluster_size);
return detail::make_tma_copy_tiled<TmaInternalType>(copy_op,
gtensor,
slayout,
cta_t_tile,
cta_v_tile);
auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler);
auto cta_t_tile = make_layout(cluster_size);
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
return detail::make_tma_copy_tiled<TmaType>(copy_op,
gtensor, slayout,
cta_t_tile, cta_v_tile);
}
// Explicit defaulting
template <class CopyOp,
class GEngine, class GLayout,
class SLayout,
class CTA_Tile,
class Cluster_Size>
CUTE_HOST_RTC
auto
make_tma_copy(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor,
SLayout const& slayout,
CTA_Tile const& cta_tile,
Cluster_Size const& cluster_size)
{
using TmaInternalType = typename GEngine::value_type;
return make_tma_copy<TmaInternalType>(copy_op,
gtensor,
slayout,
cta_tile,
cluster_size);
}
template <class CopyOp,
class GEngine, class GLayout,
class SLayout>
@ -1039,6 +1122,7 @@ make_tma_copy(CopyOp const& copy_op,
return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{});
}
// Explicit defaulting
template <class CopyOp,
class GEngine, class GLayout,
class SLayout,
@ -1053,4 +1137,86 @@ make_tma_copy(CopyOp const& copy_op,
return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size);
}
////////////////////////////////////
// Experimental Make TMA Atom and Partitioner
///////////////////////////////////
template <class TmaInternalType = void,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class CTA_Tiler,
class Cluster_Size>
CUTE_HOST_RTC
auto
make_tma_atom(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor,
SLayout const& slayout,
CTA_Tiler const& cta_tiler,
Cluster_Size const& cluster_size)
{
auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler);
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
return detail::make_tma_copy_atom<TmaType>(copy_op,
gtensor, slayout,
size(cluster_size), cta_v_tile);
}
// The "VectorCopy Partitioner" for TMA
template <class... Args,
class CtaCoord,
class TShape, class TStride,
class SEngine, class SLayout,
class GEngine, class GLayout>
CUTE_DEVICE
auto
tma_partition(Copy_Atom<Args...> const& copy_atom,
CtaCoord const& cta_coord,
Layout<TShape,TStride> const& cta_layout, // T: CTA coord -> logical multicast id
Tensor<SEngine,SLayout> const& stensor, // SMEM Tensor (TMATile, Iter)
Tensor<GEngine,GLayout> const& gtensor) // GMEM Tensor (TMATile, Iter)
{
// Invert the smem to get the largest contiguous vector in the smem layout
Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor)));
// Scale that up to cover all of the smem_coords
Layout layout_v = tile_to_shape(make_layout(inv_smem_layout), size<0>(stensor));
// Factor out the single-instruction portion
Layout tma_layout_v = make_layout(Int<Copy_Atom<Args...>::NumValSrc>{});
Layout layout_V = logical_divide(layout_v, tma_layout_v);
// Transform tile mode and coalesce
Tensor gtensor_v = coalesce(gtensor.compose(layout_V, _), Shape<Shape<_1,_1>,_1>{}); // ((TMA,TMA_Iter),Iter)
Tensor stensor_v = coalesce(stensor.compose(layout_V, _), Shape<Shape<_1,_1>,_1>{}); // ((TMA,TMA_Iter),Iter)
#if 0
if (thread0()) {
print("layout_V : "); print(layout_V); print("\n");
print("gtensor_v : "); print(gtensor_v); print("\n");
print("stensor_v : "); print(stensor_v); print("\n");
}
#endif
// Restride the cta-into-tma-instr layout
Layout tma_layout_t = composition(make_layout(Int<1>{}, shape_div(size(tma_layout_v), cosize(cta_layout))), cta_layout);
Layout tma_layout_tv = make_layout(tma_layout_t, tma_layout_v);
// Transform TMA mode
Tensor gtensor_tv = gtensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter)
Tensor stensor_tv = stensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter)
#if 0
if (thread0()) {
print("tma_layout_tv : "); print(tma_layout_tv); print("\n");
print("gtensor_tv : "); print(gtensor_tv); print("\n");
print("stensor_tv : "); print(stensor_tv); print("\n");
}
#endif
// Slice and group Frg,TMA_Iter and return
auto c = make_coord(make_coord(make_coord(cta_coord, _), _), _);
return cute::make_tuple(group_modes<0,2>(gtensor_tv(c)), group_modes<0,2>(stensor_tv(c)));
}
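// Illustrative usage sketch (tensor names and shapes are assumptions, not from this commit):
//   auto tma = make_tma_atom(SM90_TMA_LOAD{}, gA_full, sA_layout, cta_tiler, cluster_size);
//   // with sA and gA already arranged as (TMATile, Iter):
//   auto [tAgA, tAsA] = tma_partition(tma, cta_coord_in_cluster, cta_layout, sA, gA);
//   copy(tma.with(barrier, mcast_mask), tAgA, tAsA);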
} // end namespace cute

View File

@ -31,12 +31,12 @@
#pragma once
#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
#include <cute/tensor.hpp>
#include <cute/arch/mma.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/tensor.hpp>
#include <cute/util/type_traits.hpp>
namespace cute {
@ -196,39 +196,37 @@ struct MMA_Atom<MMA_Traits<Args...>>
template <class TiledMMA, class ThrCoord>
struct ThrMMA;
// @tparam MMA_Atom The MMA_Atom to use in the TiledMMA
// @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed.
// @tparam PermutationMNK Permutations to apply to each MNK-mode before tiling for the Atom.
template <class MMA_Atom,
class AtomLayoutMNK = Layout<Shape<_1,_1,_1>>,
class ValLayoutMNK = Layout<Shape<_1,_1,_1>>,
class PermutationsMNK = Tile<Underscore,Underscore,Underscore>>
class AtomLayoutMNK,
class PermutationMNK = Tile<Underscore,Underscore,Underscore>>
struct TiledMMA : MMA_Atom
{
static_assert(rank_v<AtomLayoutMNK> == 3, "TiledMMA requires rank-3 AtomLayoutMNK");
static_assert(rank_v<ValLayoutMNK> == 3, "TiledMMA requires rank-3 ValLayoutMNK");
static_assert(rank_v<PermutationsMNK> == 3, "TiledMMA requires rank-3 PermutationsMNK");
using Atom = MMA_Atom;
using AtomShape_MNK = typename MMA_Atom::Shape_MNK;
using AtomThrID = typename MMA_Atom::ThrID;
using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV;
using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV;
using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV;
// ThrV -> thread_idx
using AtomThrID = typename MMA_Atom::ThrID;
static_assert( rank_v<AtomLayoutMNK> == 3, "TiledMMA requires rank-3 AtomLayoutMNK");
static_assert( rank_v<PermutationMNK> == 3, "TiledMMA requires rank-3 PermutationMNK");
static_assert( is_tile<PermutationMNK>::value, "TiledMMA requires independent permutations of MNK.");
static_assert(is_static<PermutationMNK>::value, "TiledMMA requires static permutations of MNK.");
// (M,N,K)
using TiledShape_MNK = decltype(make_shape(size<0>(AtomShape_MNK{})*size<0>(AtomLayoutMNK{})*size<0>(ValLayoutMNK{}),
size<1>(AtomShape_MNK{})*size<1>(AtomLayoutMNK{})*size<1>(ValLayoutMNK{}),
size<2>(AtomShape_MNK{})*size<2>(AtomLayoutMNK{})*size<2>(ValLayoutMNK{})));
// thrid = (ThrV,ThrM,ThrN,ThrK) -> thr_idx
using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{}));
ThrLayoutVMNK thr_layout_vmnk_;
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
using TidLayout = decltype(right_inverse(ThrLayoutVMNK{}));
CUTE_HOST_DEVICE constexpr
TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {})
: MMA_Atom(mma_atom),
thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {}
CUTE_HOST_DEVICE constexpr auto
get_thr_layout_vmnk() const {
return ThrLayoutVMNK{};
return thr_layout_vmnk_;
}
// Tile a tensor or a layout from shape
@ -243,17 +241,17 @@ struct TiledMMA : MMA_Atom
// RestM: The values tiled in M.
// RestN: The values tiled in N.
template <class CTensor>
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
thrfrg_C(CTensor&& ctensor)
thrfrg_C(CTensor&& ctensor) const
{
CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{});
CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
//CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
//CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
// Reorder the tensor for the TiledAtom
auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})),
left_inverse(get<1>(PermutationsMNK{})));
auto t_tile = make_tile(get<0>(PermutationMNK{}),
get<1>(PermutationMNK{}));
auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN)
// Tile the tensor for the Atom
@ -266,25 +264,13 @@ struct TiledMMA : MMA_Atom
// Tile the tensor for the C-threads
auto thr_tile = make_tile(_,
make_tile(make_layout(size<1>(ThrLayoutVMNK{})),
make_layout(size<2>(ThrLayoutVMNK{}))));
make_tile(make_layout(size<1>(thr_layout_vmnk_)),
make_layout(size<2>(thr_layout_vmnk_))));
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN)))
return thr_tensor;
}
// Tile from (M,N,...)
// to (thr_idx,(FrgV,(RestM,RestN,...)))
template <class CTensor>
CUTE_HOST_DEVICE constexpr static
auto
tidfrg_C(CTensor&& ctensor)
{
// Don't need a ctile composition because ThrK is last mode in TidLayout
return thrfrg_C(ctensor).compose(TidLayout{}, _);
}
// Tile a tensor or a layout from shape
// (M,K,...)
// to shape
@ -297,17 +283,17 @@ struct TiledMMA : MMA_Atom
// RestM: The values tiled in M.
// RestK: The values tiled in K.
template <class ATensor>
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
thrfrg_A(ATensor&& atensor)
thrfrg_A(ATensor&& atensor) const
{
CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{});
//CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
//UTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{});
//CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{});
// Reorder the tensor for the TiledAtom
auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})),
left_inverse(get<2>(PermutationsMNK{})));
auto t_tile = make_tile(get<0>(PermutationMNK{}),
get<2>(PermutationMNK{}));
auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK)
// Tile the tensor for the Atom
@ -320,29 +306,13 @@ struct TiledMMA : MMA_Atom
// Tile the tensor for the Thread
auto thr_tile = make_tile(_,
make_tile(make_layout(size<1>(ThrLayoutVMNK{})),
make_layout(size<3>(ThrLayoutVMNK{}))));
make_tile(make_layout(size<1>(thr_layout_vmnk_)),
make_layout(size<3>(thr_layout_vmnk_))));
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK)))
return thr_tensor;
}
// Tile from (M,K,...)
// to (thr_idx,(FrgV,(RestM,RestK,...)))
template <class ATensor>
CUTE_HOST_DEVICE constexpr static
auto
tidfrg_A(ATensor&& atensor)
{
auto atile = make_tile(_,
make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})),
make_stride( Int<1>{} , Int<0>{} )),
_));
// (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
return thrfrg_A(atensor).compose(atile, _).compose(TidLayout{}, _);
}
// Tile a tensor or a layout from shape
// (N,K,...)
// to shape
@ -355,17 +325,17 @@ struct TiledMMA : MMA_Atom
// RestN: The values tiled in N.
// RestK: The values tiled in K.
template <class BTensor>
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
thrfrg_B(BTensor&& btensor)
thrfrg_B(BTensor&& btensor) const
{
CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{});
//CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
//CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{});
// Reorder the tensor for the TiledAtom
auto t_tile = make_tile(left_inverse(get<1>(PermutationsMNK{})),
left_inverse(get<2>(PermutationsMNK{})));
auto t_tile = make_tile(get<1>(PermutationMNK{}),
get<2>(PermutationMNK{}));
auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK)
// Tile the tensor for the Atom
@ -378,44 +348,28 @@ struct TiledMMA : MMA_Atom
// Tile the tensor for the Thread
auto thr_tile = make_tile(_,
make_tile(make_layout(size<2>(ThrLayoutVMNK{})),
make_layout(size<3>(ThrLayoutVMNK{}))));
make_tile(make_layout(size<2>(thr_layout_vmnk_)),
make_layout(size<3>(thr_layout_vmnk_))));
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK)))
return thr_tensor;
}
// Tile from (N,K,...)
// to (thr_idx,(FrgV,(RestN,RestK,...)))
template <class BTensor>
CUTE_HOST_DEVICE constexpr static
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE constexpr
auto
tidfrg_B(BTensor&& btensor)
get_slice(ThrIdx const& thr_idx) const
{
auto btile = make_tile(_,
make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})),
make_stride( Int<0>{} , Int<1>{} )),
_));
// (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
return thrfrg_B(btensor).compose(btile, _).compose(TidLayout{}, _);
auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx);
return ThrMMA<TiledMMA, decltype(thr_vmnk)>{*this, thr_vmnk};
}
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE static constexpr
CUTE_HOST_DEVICE constexpr
auto
get_slice(ThrIdx const& thr_idx)
{
auto thr_vmnk = ThrLayoutVMNK{}.get_flat_coord(thr_idx);
return ThrMMA<TiledMMA, decltype(thr_vmnk)>(thr_vmnk);
}
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE static constexpr
auto
get_thread_slice(ThrIdx const& thr_idx)
get_thread_slice(ThrIdx const& thr_idx) const
{
return get_slice(thr_idx);
}
@ -424,104 +378,144 @@ struct TiledMMA : MMA_Atom
// Utility for printing and visualization
//
CUTE_HOST_DEVICE constexpr static
// The size of the MNK-mode
template <int I>
CUTE_HOST_DEVICE constexpr
auto
get_layoutC_MN()
tile_size_mnk() const {
static_assert(0 <= I && I < 3);
auto core_size = size<I>(AtomShape_MNK{}) * size<I+1>(get_thr_layout_vmnk());
[[maybe_unused]] auto perm_size = size<I>(PermutationMNK{});
if constexpr (is_underscore<decltype(perm_size)>::value) {
return core_size;
} else {
return cute::max(core_size, perm_size);
}
}
CUTE_HOST_DEVICE constexpr
auto
get_layoutC_MN() const
{
// (M,N) -> (M,N)
auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{})));
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
// (cthrid,val) -> (M,N)
auto layoutC_TV = thrfrg_C(ref_C);
// (M,N) -> (cthrid,frg)
auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C));
// cthrid = (v,m,n) -> thr_idx
auto thrID_C = ThrLayoutVMNK{}(_,_,_,Int<0>{});
auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{});
return cute::make_tuple(layoutC_MN, thrID_C);
}
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
get_layoutC_TV()
get_layoutC_TV() const
{
// (M,N) -> (M,N)
auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{})));
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
// (cthrid,val) -> (M,N)
auto layoutC_TV = thrfrg_C(ref_C);
return tidfrg_C(ref_C);
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
// (thr_idx,val) -> (M,N)
return layoutC_TV.compose(thridx_2_thrid, _);
}
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
get_layoutA_MK()
get_layoutA_MK() const
{
// (M,K) -> (M,K)
auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>()));
// (athrid,val) -> (M,K)
auto layoutA_TV = thrfrg_A(ref_A);
// (M,K) -> (athrid,frg)
auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A));
// athrid = (v,m,k) -> thr_idx
auto thrID_A = ThrLayoutVMNK{}(_,_,Int<0>{},_);
auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_);
return cute::make_tuple(layoutA_MK, thrID_A);
}
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
get_layoutA_TV()
get_layoutA_TV() const
{
// (M,K) -> (M,K)
auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>()));
// (athrid,val) -> (M,K)
auto layoutA_TV = thrfrg_A(ref_A);
return tidfrg_A(ref_A);
// (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
auto atile = make_tile(_,
make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)),
make_stride( Int<1>{} , Int<0>{} )),
_));
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
// (thr_idx,val) -> (M,K)
return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _);
}
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
get_layoutB_NK()
get_layoutB_NK() const
{
// (N,K) -> (N,K)
auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>()));
// (bthrid,val) -> (N,K)
auto layoutB_TV = thrfrg_B(ref_B);
// (N,K) -> (bthrid,frg)
auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B));
// bthrid = (v,n,k) -> thr_idx
auto thrID_B = ThrLayoutVMNK{}(_,Int<0>{},_,_);
auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_);
return cute::make_tuple(layoutB_NK, thrID_B);
}
CUTE_HOST_DEVICE constexpr static
CUTE_HOST_DEVICE constexpr
auto
get_layoutB_TV()
get_layoutB_TV() const
{
// (N,K) -> (N,K)
auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>()));
// (bthrid,val) -> (N,K)
auto layoutB_TV = thrfrg_B(ref_B);
return tidfrg_B(ref_B);
// (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
auto btile = make_tile(_,
make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)),
make_stride( Int<0>{} , Int<1>{} )),
_));
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
// (thr_idx,val) -> (N,K)
return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _);
}
};
template <class TiledMMA, class ThrVMNK>
struct ThrMMA : TiledMMA
{
// Use ThrVMNK and thrfrg rather than thr_idx and tidfrg
// to support swizzled threads partitioning dynamic layouts
ThrVMNK thr_vmnk_;
CUTE_HOST_DEVICE constexpr
ThrMMA(ThrVMNK const& thr_vmnk) : thr_vmnk_(thr_vmnk) {}
template <class CTensor>
CUTE_HOST_DEVICE constexpr
auto
partition_C(CTensor&& ctensor) const
{
auto thr_tensor = make_tensor(std::forward<CTensor>(ctensor).data(), TiledMMA::thrfrg_C(ctensor.layout()));
auto thr_tensor = make_tensor(std::forward<CTensor>(ctensor).data(), this->thrfrg_C(ctensor.layout()));
auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_)));
return thr_tensor(thr_vmn, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
@ -532,7 +526,7 @@ struct ThrMMA : TiledMMA
auto
partition_A(ATensor&& atensor) const
{
auto thr_tensor = make_tensor(std::forward<ATensor>(atensor).data(), TiledMMA::thrfrg_A(atensor.layout()));
auto thr_tensor = make_tensor(std::forward<ATensor>(atensor).data(), this->thrfrg_A(atensor.layout()));
auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_)));
return thr_tensor(thr_vmk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
@ -543,7 +537,7 @@ struct ThrMMA : TiledMMA
auto
partition_B(BTensor&& btensor) const
{
auto thr_tensor = make_tensor(std::forward<BTensor>(btensor).data(), TiledMMA::thrfrg_B(btensor.layout()));
auto thr_tensor = make_tensor(std::forward<BTensor>(btensor).data(), this->thrfrg_B(btensor.layout()));
auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_)));
return thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
@ -580,38 +574,32 @@ struct ThrMMA : TiledMMA
template <class MMA_Op,
class MMAThrLayout = Layout<Shape<_1,_1,_1>>,
class MMAValLayout = Layout<Shape<_1,_1,_1>>,
class Permutations = Tile<Underscore,Underscore,Underscore>>
CUTE_HOST_DEVICE constexpr
auto
make_tiled_mma(MMA_Atom<MMA_Op> const&,
make_tiled_mma(MMA_Atom<MMA_Op> const& mma_atom,
MMAThrLayout const& thr_layout = {},
MMAValLayout const& val_layout = {},
Permutations const& permutations = {})
{
auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{});
auto val_layout_mnk = append<3>(val_layout, Layout<_1,_0>{});
auto permutation_mnk = append<3>(permutations, _);
return TiledMMA<MMA_Atom<MMA_Op>,
decltype(thr_layout_mnk),
decltype(val_layout_mnk),
decltype(permutation_mnk)>{};
decltype(permutation_mnk)>{mma_atom, thr_layout_mnk};
}
template <class MMA_Op,
class MMAThrLayout = Layout<Shape<_1,_1,_1>>,
class MMAValLayout = Layout<Shape<_1,_1,_1>>,
class Permutations = Tile<Underscore,Underscore,Underscore>>
CUTE_HOST_DEVICE constexpr
auto
make_tiled_mma(MMA_Op const&,
MMAThrLayout const& thr_layout = {},
MMAValLayout const& val_layout = {},
Permutations const& permutations = {})
{
// Attempt to wrap in an MMA_Atom<> and forward
return make_tiled_mma(MMA_Atom<MMA_Op>{}, thr_layout, val_layout, permutations);
return make_tiled_mma(MMA_Atom<MMA_Op>{}, thr_layout, permutations);
}
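As a quick orientation for the reworked construction path (ValLayoutMNK is gone and the atom instance is now carried into the TiledMMA), a hedged usage sketch; the SM80 atom below is just one of CuTe's existing operations and is not specific to this commit:

```cpp
// Hedged sketch: building and slicing a TiledMMA with the updated interface.
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm80.hpp>

void tiled_mma_example() {
  using namespace cute;
  // 2x2x1 arrangement of a 16x8x16 tensor-core atom -> a 32x16x16 tiled MMA.
  auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                  Layout<Shape<_2,_2,_1>>{});  // AtomLayoutMNK
  auto thr_mma   = tiled_mma.get_slice(0);                     // per-thread view (ThrMMA)
  (void)thr_mma;
}
```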
//
@ -680,28 +668,38 @@ partition_shape_B(TiledMMA<Args...> const& mma, Shape_NK const& shape_NK)
// Size
//
template <int... I, class... Args>
template <int I, class... Args>
CUTE_HOST_DEVICE constexpr
auto
tile_size(TiledMMA<Args...> const& mma)
{
return size<I...>(typename TiledMMA<Args...>::TiledShape_MNK{});
return mma.template tile_size_mnk<I>();
}
template <int... I, class... Args>
template <class... Args>
CUTE_HOST_DEVICE constexpr
auto
tile_shape(TiledMMA<Args...> const& mma)
{
return shape<I...>(typename TiledMMA<Args...>::TiledShape_MNK{});
return make_shape(tile_size<0>(mma), tile_size<1>(mma), tile_size<2>(mma));
}
// Deprecate?
template <int... I, class... Args>
CUTE_HOST_DEVICE constexpr
auto
size(TiledMMA<Args...> const& mma)
{
return size<I...>(typename TiledMMA<Args...>::ThrLayoutVMNK{});
return size<I...>(mma.get_thr_layout_vmnk());
}
// Alias
template <int... I, class... Args>
CUTE_HOST_DEVICE constexpr
auto
thr_size(TiledMMA<Args...> const& mma)
{
return size<I...>(mma.get_thr_layout_vmnk());
}
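A hedged follow-up to the sketch above showing how the replacement accessors are meant to be queried (same includes assumed; the commented values follow from the 2x2x1 tiling of a 16x8x16 atom):

```cpp
// Hedged sketch: extents of a TiledMMA via the new free functions.
void tiled_mma_query_example() {
  using namespace cute;
  auto mma  = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{}, Layout<Shape<_2,_2,_1>>{});
  auto mnk  = tile_shape(mma);     // (_32,_16,_16)
  auto m    = tile_size<0>(mma);   // _32
  auto nthr = thr_size(mma);       // _128: 4 atoms x 32 threads each
  (void)mnk; (void)m; (void)nthr;
}
```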
//
@ -715,33 +713,31 @@ print(MMA_Atom<MMA_Traits<Args...>> const&)
{
using Atom = MMA_Atom<MMA_Traits<Args...>>;
print("MMA_Atom\n");
print(" ThrID: "); print(typename Atom::ThrID{}); print("\n");
print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n");
print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n");
print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n");
print(" ThrID: "); print(typename Atom::ThrID{}); print("\n");
print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n");
print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n");
print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n");
}
template <class Atom, class TiledThr, class TiledVal, class TiledPerm>
template <class Atom, class TiledThr, class TiledPerm>
CUTE_HOST_DEVICE
void
print(TiledMMA<Atom, TiledThr, TiledVal, TiledPerm> const& mma)
print(TiledMMA<Atom, TiledThr, TiledPerm> const& mma)
{
using MMA = TiledMMA<Atom, TiledThr, TiledVal, TiledPerm>;
print("TiledMMA\n");
print(" TiledThr: "); print(TiledThr{}); print("\n");
print(" TiledVal: "); print(TiledVal{}); print("\n");
print(" TiledPerm: "); print(TiledPerm{}); print("\n");
print(" TiledShape_MNK: "); print(typename MMA::TiledShape_MNK{}); print("\n");
print(" ThrLayoutVMNK: "); print(typename MMA::ThrLayoutVMNK{}); print("\n");
print(" ThrLayoutVMNK: "); print(mma.get_thr_layout_vmnk()); print("\n");
print(" PermutationMNK: "); print(TiledPerm{}); print("\n");
print(static_cast<Atom const&>(mma));
}
template <class TiledMMA, class ThrVMNK>
CUTE_HOST_DEVICE
void
print(ThrMMA<TiledMMA, ThrVMNK> const&)
print(ThrMMA<TiledMMA, ThrVMNK> const& thr_mma)
{
print(TiledMMA{});
print("ThrMMA\n");
print(" Thr VMNK: "); print(thr_mma.thr_vmnk_); print("\n");
print(static_cast<TiledMMA>(thr_mma));
}
template <class... Args>
@ -766,18 +762,6 @@ print_latex(TiledMMA<Args...> const& mma)
layoutB_NK, thrID_B);
}
// EXPERIMENTAL -- Doesn't work with Swizzled Thr TileMMAs...
template <class... Args>
CUTE_HOST_DEVICE
auto
print_latex_2(TiledMMA<Args...> const& mma)
{
print_latex_mma(typename TiledMMA<Args...>::TiledShape_MNK{},
mma.get_layoutC_TV(),
mma.get_layoutA_TV(),
mma.get_layoutB_TV());
}
// MNK MMA Layout to console printer -- 8-value color coded by thread
template <class LayoutC, class ThrIDC,
class LayoutA, class ThrIDA,
@ -943,122 +927,6 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
printf(latex_footer);
}
// ThrVal MMA Layout to Latex TIKZ -- 8-value color coded by thread
template <class Shape_MNK,
class LayoutC, class LayoutA, class LayoutB>
CUTE_HOST_DEVICE
void
print_latex_mma(Shape_MNK const& shape_mnk,
LayoutC const& C, // (thr_idx,vid) -> (m,n)
LayoutA const& A, // (thr_idx,vid) -> (m,k)
LayoutB const& B) // (thr_idx,vid) -> (n,k)
{
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
char const* latex_header =
"\\documentclass{standalone}\n"
"\\usepackage{tikz}\n"
"\\usetikzlibrary{external}\n"
"\\tikzexternalize\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n";
char const* latex_footer =
"\\end{tikzpicture}\n"
"\\end{document}\n";
char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
"{rgb,255:red,175;green,255;blue,175}",
"{rgb,255:red,255;green,255;blue,175}",
"{rgb,255:red,255;green,175;blue,175}",
"{rgb,255:red,210;green,210;blue,255}",
"{rgb,255:red,210;green,255;blue,210}",
"{rgb,255:red,255;green,255;blue,210}",
"{rgb,255:red,255;green,210;blue,210}"};
// Header
printf("%% Shape_MNK: "); print(shape_mnk); printf("\n");
printf("%% LayoutC : "); print(C); printf("\n");
printf("%% LayoutA : "); print(A); printf("\n");
printf("%% LayoutB : "); print(B); printf("\n\n");
printf(latex_header);
auto M = size<0>(shape_mnk);
auto N = size<1>(shape_mnk);
auto K = size<2>(shape_mnk);
// C starting at 0,0
bool c_filled[M][N] = {};
for (int t = 0; t < size<0>(C); ++t) {
for (int v = 0; v < size<1>(C); ++v) {
int m = C(t,v) % M;
int n = C(t,v) / M;
if (not c_filled[m][n]) {
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color_map[t % 8],
m, n,
t, v);
c_filled[m][n] = true;
}
}
}
// A starting at 0,-size<1>(A)-1
bool a_filled[M][K] = {};
for (int t = 0; t < size<0>(A); ++t) {
for (int v = 0; v < size<1>(A); ++v) {
int m = A(t,v) % M;
int k = A(t,v) / M;
if (not a_filled[m][k]) {
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color_map[t % 8],
m, k - 1 - K,
t, v);
a_filled[m][k] = true;
}
}
}
// B starting at -size<1>(B)-1,0
bool b_filled[N][K] = {};
for (int t = 0; t < size<0>(B); ++t) {
for (int v = 0; v < size<1>(B); ++v) {
int n = B(t,v) % N;
int k = B(t,v) / N;
if (not b_filled[n][k]) {
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color_map[t % 8],
k - 1 - K, n,
t, v);
b_filled[n][k] = true;
}
}
}
// A labels
for (int m = 0, k = -1; m < M; ++m) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, m);
}
for (int k = 0, m = -1; k < K; ++k) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, k);
}
// B labels
for (int n = 0, k = -1; n < N; ++n) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, n);
}
for (int k = 0, n = -1; k < K; ++k) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, k);
}
// Footer
printf(latex_footer);
}
} // namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -47,7 +47,7 @@
#endif
#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \
(defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
(defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
# define CUTE_UNROLL #pragma unroll
# define CUTE_NO_UNROLL #pragma unroll 1
#elif defined(__CUDACC_RTC__) || defined(__clang__)
@ -120,6 +120,8 @@
#include <cassert>
#endif
#define CUTE_STATIC_V(x) decltype(x)::value
#define CUTE_STATIC_ASSERT static_assert
#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__)

View File

@ -91,7 +91,7 @@ private:
// Flag for fast branching on straddled elements
static constexpr bool is_storage_unaligned = ((sizeof_bits_v<storage_type> % sizeof_bits_v<element_type>) != 0);
friend class subbyte_iterator<T>;
friend struct subbyte_iterator<T>;
// Pointer to storage element
storage_type* ptr_ = nullptr;
@ -208,7 +208,7 @@ struct subbyte_iterator
private:
template <class, class> friend class swizzle_ptr;
template <class, class> friend struct swizzle_ptr;
// Pointer to storage element
storage_type* ptr_ = nullptr;

View File

@ -327,7 +327,7 @@ ceil_div(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implictly 1
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); });
} else {
return (a + b - Int<1>{}) / b;
@ -336,6 +336,28 @@ ceil_div(IntTupleA const& a, IntTupleB const& b)
CUTE_GCC_UNREACHABLE;
}
//
// round_up
// Round @a a up to the nearest multiple of @a b.
// For negative numbers, rounds away from zero.
//
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
round_up(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); });
} else {
return ((a + b - Int<1>{}) / b) * b;
}
CUTE_GCC_UNREACHABLE;
}
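A brief hedged illustration of the new round_up (values are arbitrary):

```cpp
// Hedged illustration: scalar and IntTuple forms of round_up.
void round_up_example() {
  auto r0 = cute::round_up(cute::Int<10>{}, cute::Int<4>{});              // Int<12>
  auto r1 = cute::round_up(cute::make_shape(10, 6), cute::make_shape(4)); // (12,6): missing ranks of b default to 1
  (void)r0; (void)r1;
}
```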
/** Division for Shapes
* Case Tuple Tuple:
* Perform shape_div element-wise
@ -429,6 +451,7 @@ template <class A, class B>
using is_congruent = decltype(congruent(declval<A>(), declval<B>()));
/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division)
* weakly_congruent is a partial order on A and B: A <= B
*/
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
@ -458,7 +481,7 @@ using is_weakly_congruent = decltype(weakly_congruent(declval<A>(), declval<B>()
/** Test if Shape B is compatible with Shape A:
* Any coordinate into A can also be used as a coordinate into B
* A <= B is a partially ordered set of factored shapes
* compatible is a partial order on A and B: A <= B
*/
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
@ -487,7 +510,8 @@ template <class A, class B>
using is_compatible = decltype(compatible(declval<A>(), declval<B>()));
/** Test if Shape B is weakly compatible with Shape A:
* Shape B divides Shape A at some level of refinement
* Shape B is a multiple of a shape that is compatible with Shape A
* weakly_compatible is a partial order on A and B: A <= B
*/
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
@ -502,7 +526,7 @@ weakly_compatible(IntTupleA const& a, IntTupleB const& b)
[](auto const&... z) { return (true_type{} && ... && z); });
}
} else if constexpr (is_integral<IntTupleA>::value) {
return a % size(b) == Int<0>{};
return size(b) % a == Int<0>{};
} else if constexpr (is_integral<IntTupleB>::value) {
return false_type{};
} else {

View File

@ -981,7 +981,6 @@ auto
composition(Layout<LShape,LStride> const& lhs,
Layout<RShape,RStride> const& rhs)
{
//return detail::composition_impl(flatten(lhs), rhs.shape(), rhs.stride());
return detail::composition_impl(lhs, rhs.shape(), rhs.stride());
}
@ -997,8 +996,8 @@ composition(Layout<LShape,LStride> const& lhs,
return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq<tuple_size<IntTuple>::value>{}, seq<>{}, seq<>{});
} else if constexpr (is_underscore<IntTuple>::value) {
return lhs;
} else {
return composition(lhs, make_layout(rhs));
} else if constexpr (is_integral<IntTuple>::value) {
return detail::composition_impl(lhs, rhs, Int<1>{});
}
CUTE_GCC_UNREACHABLE;
@ -1097,15 +1096,18 @@ inverse_seq(Shape const& shape, Stride const& stride, seq<Is...>)
auto next_I = cute::find_if(stride, [](auto a) { return is_constant<NextStride, decltype(a)>{}; });
if constexpr (next_I == decltype(rank(stride))::value) {
// If not found, return current seq
return seq<Is...>{};
} else {
// auto next_stride = get<next_I>(shape) * get<next_I>(stride);
// NOTE: Needed for g++-7
using next_stride = decltype(get<next_I>(shape) * get<next_I>(stride));
if constexpr (is_static<next_stride>::value) {
if constexpr (is_static<next_stride>::value && !is_constant<NextStride, next_stride>::value) {
// If next_stride is static and unique, then continue
return inverse_seq<next_stride::value>(shape, stride, seq<Is..., next_I>{});
} else {
// Else return current seq + next_I
return seq<Is..., next_I>{};
}
}
@ -1340,28 +1342,24 @@ template <class LShape, class LStride,
CUTE_HOST_DEVICE constexpr
auto
logical_divide(Layout<LShape,LStride> const& layout,
Layout<TShape,TStride> const& tile)
Layout<TShape,TStride> const& tiler)
{
//CUTE_STATIC_ASSERT_V(size(layout) % size(tile) == Int<0>{},
// "Tiling does not evenly divide the block");
// NOTE: With tiles that have stride-0, this doesn't have to be true
return composition(layout, make_layout(tile, complement(tile, size(layout))));
return composition(layout, make_layout(tiler, complement(tiler, size(layout))));
}
template <class LShape, class LStride, class IntTuple>
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(Layout<LShape,LStride> const& layout,
IntTuple const& tile)
Tiler const& tiler)
{
if constexpr (is_tuple<IntTuple>::value) {
static_assert(tuple_size<IntTuple>::value <= Layout<LShape,LStride>::rank, "logical_divide: Too many modes in tile.");
return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_divide(l,t); });
} else if constexpr (is_underscore<IntTuple>::value) {
if constexpr (is_tuple<Tiler>::value) {
static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_divide: Too many modes in tiler.");
return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_divide(l,t); });
} else if constexpr (is_underscore<Tiler>::value) {
return layout;
} else if constexpr (is_integral<IntTuple>::value) {
return logical_divide(layout, make_layout(tile));
} else if constexpr (is_integral<Tiler>::value) {
return logical_divide(layout, make_layout(tiler));
}
CUTE_GCC_UNREACHABLE;
@ -1374,24 +1372,24 @@ logical_divide(Layout<LShape,LStride> const& layout,
//
template <class LShape, class LStride,
class Tile>
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(Layout<LShape,LStride> const& layout,
Tile const& tile)
Tiler const& tiler)
{
return tile_unzip(logical_divide(layout, tile), tile);
return tile_unzip(logical_divide(layout, tiler), tiler);
}
// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y)
template <class LShape, class LStride,
class Tile>
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_divide(Layout<LShape,LStride> const& layout,
Tile const& tile)
Tiler const& tiler)
{
auto div = zipped_divide(layout, tile);
auto div = zipped_divide(layout, tiler);
auto R = rank<1>(div);
return div(_, repeat<R>(_));
@ -1399,13 +1397,13 @@ tiled_divide(Layout<LShape,LStride> const& layout,
// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y)
template <class LShape, class LStride,
class Tile>
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_divide(Layout<LShape,LStride> const& layout,
Tile const& tile)
Tiler const& tiler)
{
auto div = zipped_divide(layout, tile);
auto div = zipped_divide(layout, tiler);
auto R0 = rank<0>(div);
auto R1 = rank<1>(div);
@ -1421,24 +1419,24 @@ template <class LShape, class LStride,
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<LShape,LStride> const& layout,
Layout<TShape,TStride> const& tile)
Layout<TShape,TStride> const& tiler)
{
return make_layout(layout, composition(complement(layout, size(layout)*cosize(tile)), tile));
return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler));
}
template <class LShape, class LStride, class IntTuple>
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<LShape,LStride> const& layout,
IntTuple const& tile)
Tiler const& tiler)
{
if constexpr (is_tuple<IntTuple>::value) {
static_assert(tuple_size<IntTuple>::value <= Layout<LShape,LStride>::rank);
return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_product(l,t); });
} else if constexpr (is_underscore<IntTuple>::value) {
if constexpr (is_tuple<Tiler>::value) {
static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_product: Too many modes in tiler.");
return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); });
} else if constexpr (is_underscore<Tiler>::value) {
return layout;
} else if constexpr (is_integral<IntTuple>::value) {
return logical_product(layout, make_layout(tile));
} else if constexpr (is_integral<Tiler>::value) {
return logical_product(layout, make_layout(tiler));
}
CUTE_GCC_UNREACHABLE;
@ -1451,45 +1449,43 @@ logical_product(Layout<LShape,LStride> const& layout,
//
template <class LShape, class LStride,
class Tile>
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_product(Layout<LShape,LStride> const& layout,
Tile const& tile)
Tiler const& tiler)
{
return tile_unzip(logical_product(layout, tile), tile);
return tile_unzip(logical_product(layout, tiler), tiler);
}
// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y)
template <class LShape, class LStride,
class Tile>
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_product(Layout<LShape,LStride> const& layout,
Tile const& tile)
Tiler const& tiler)
{
auto div = zipped_product(layout, tile);
auto div = zipped_product(layout, tiler);
auto R = rank(tile);
auto R = rank<1>(div);
return div(_, repeat<R>(_));
}
// Attempts to reproduce layout "block" over layout "layout"
// That is, think of every element of "layout" as a "block"
// Attempts to reproduce a layout over a tiler
// That is, think of every element of "tiler" as a "layout"
// and return the layout of the resulting structure
template <class TShape, class TStride,
class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
blocked_product(Layout<TShape,TStride> const& block,
Layout<UShape,UStride> const& layout)
blocked_product(Layout<TShape,TStride> const& layout,
Layout<UShape,UStride> const& tiler)
{
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
auto padded_block = append<R>(block);
auto padded_layout = append<R>(layout);
auto result = logical_product(padded_block, padded_layout);
auto result = logical_product(append<R>(layout), append<R>(tiler));
return coalesce(zip(get<0>(result), get<1>(result)), repeat<R>(Int<1>{}));
}
@ -1497,14 +1493,12 @@ template <class TShape, class TStride,
class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
raked_product(Layout<TShape,TStride> const& block,
Layout<UShape,UStride> const& layout)
raked_product(Layout<TShape,TStride> const& layout,
Layout<UShape,UStride> const& tiler)
{
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
auto padded_block = append<R>(block);
auto padded_layout = append<R>(layout);
auto result = logical_product(padded_block, padded_layout);
auto result = logical_product(append<R>(layout), append<R>(tiler));
return coalesce(zip(get<1>(result), get<0>(result)), repeat<R>(Int<1>{}));
}
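A short hedged example of blocked_product under the renamed parameters (shapes are illustrative; raked_product takes the same arguments but interleaves rather than blocks):

```cpp
// Hedged example: a 2x2 layout reproduced over a 3x4 tiler covers a 6x8 extent.
void blocked_product_example() {
  using namespace cute;
  auto layout = make_layout(make_shape(Int<2>{}, Int<2>{}));
  auto tiler  = make_layout(make_shape(Int<3>{}, Int<4>{}));
  auto result = blocked_product(layout, tiler);  // rank-2, (2*3) x (2*4) = 6 x 8 elements
  print(result); print("\n");
}
```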

View File

@ -473,6 +473,16 @@ zipped_divide(ComposedLayout<A,O,B> const& a,
return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
CUTE_HOST_DEVICE constexpr
auto
flat_divide(ComposedLayout<A,O,B> const& a,
Tile const& b)
{
return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
CUTE_HOST_DEVICE constexpr

View File

@ -181,7 +181,7 @@ template <class T0, class T1, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) {
return make_tuple_iter(cute::make_tuple(t0, t1, ts...));
return make_inttuple_iter(cute::make_tuple(t0, t1, ts...));
}
//

View File

@ -148,7 +148,9 @@ using _96 = Int<96>;
using _128 = Int<128>;
using _192 = Int<192>;
using _256 = Int<256>;
using _384 = Int<384>;
using _512 = Int<512>;
using _768 = Int<768>;
using _1024 = Int<1024>;
using _2048 = Int<2048>;
using _4096 = Int<4096>;

View File

@ -97,7 +97,7 @@ template <class T, class U,
__CUTE_REQUIRES(is_std_integral<T>::value &&
is_std_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
cute::common_type_t<T, U>
gcd(T t, U u) {
while (true) {
if (t == 0) { return u; }
@ -112,7 +112,7 @@ template <class T, class U,
__CUTE_REQUIRES(is_std_integral<T>::value &&
is_std_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
cute::common_type_t<T, U>
lcm(T const& t, U const& u) {
return (t / gcd(t,u)) * u;
}
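A hedged spot-check of the return-type change: gcd and lcm now advertise cute::common_type_t<T, U> instead of auto, so mixed-width calls resolve to the wider integer type.

```cpp
// Hedged spot-check: constexpr evaluation and the declared common return type.
static_assert(cute::lcm(4, 6) == 12, "lcm remains usable in constant expressions");
static_assert(cute::is_same_v<decltype(cute::gcd(int16_t(8), int64_t(12))), int64_t>,
              "mixed-width arguments yield the common (wider) integer type");
```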

View File

@ -233,14 +233,14 @@ CUTE_HOST_DEVICE void print(T const* const ptr)
template <class T>
CUTE_HOST_DEVICE void print(counting_iterator<T> ptr)
{
printf("counting_iter_"); print(ptr.n_);
printf("counting_iter("); print(ptr.n_); printf(")");
}
#if !defined(__CUDACC_RTC__)
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr)
{
return os << "counting_iter_" << ptr.n_;
return os << "counting_iter(" << ptr.n_ << ")";
}
#endif // !defined(__CUDACC_RTC__)

View File

@ -990,13 +990,10 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
{
print(tensor); print(":\n");
auto format = get_format(tensor(0));
using type = typename decltype(format)::type;
if constexpr (Layout::rank == 1)
{
for (int m = 0; m < size(tensor); ++m) {
printf(format.format, format.digits, type(tensor(m)));
pretty_print(tensor(m));
printf("\n");
}
} else
@ -1004,7 +1001,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
{
for (int m = 0; m < size<0>(tensor); ++m) {
for (int n = 0; n < size<1>(tensor); ++n) {
printf(format.format, format.digits, type(tensor(m,n)));
pretty_print(tensor(m,n));
}
printf("\n");
}
@ -1013,7 +1010,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
{
print_tensor(tensor(_,_,0));
for (int k = 1; k < size<2>(tensor); ++k) {
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n");
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n");
print_tensor(tensor(_,_,k));
}
} else
@ -1021,7 +1018,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
{
print_tensor(tensor(_,_,_,0));
for (int p = 1; p < size<3>(tensor); ++p) {
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n");
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n");
print_tensor(tensor(_,_,_,p));
}
}

View File

@ -57,61 +57,6 @@ num_digits(int x)
10)))))))));
}
template <class T>
struct format_and_size {
using type = T;
char const* format;
int digits;
};
CUTE_HOST_DEVICE
format_and_size<int>
get_format(bool) {
return {"%*d", 3};
}
CUTE_HOST_DEVICE
format_and_size<int32_t>
get_format(int32_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<uint32_t>
get_format(uint32_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<int64_t>
get_format(int64_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<uint64_t>
get_format(uint64_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<float>
get_format(half_t) {
return {"%*.2f", 8};
}
CUTE_HOST_DEVICE
format_and_size<float>
get_format(float) {
return {"%*.2e", 10};
}
CUTE_HOST_DEVICE
format_and_size<double>
get_format(double) {
return {"%*.3e", 11};
}
//
// print dispatcher
//
@ -195,4 +140,54 @@ print(char const* format) {
printf("%s", format);
}
//
// pretty printing
//
template <class T>
CUTE_HOST_DEVICE void
pretty_print(T const& v) {
printf(" "); print(v);
}
CUTE_HOST_DEVICE void
pretty_print(bool const& v) {
printf("%*d", 3, int(v));
}
CUTE_HOST_DEVICE void
pretty_print(int32_t const& v) {
printf("%*d", 5, v);
}
CUTE_HOST_DEVICE void
pretty_print(uint32_t const& v) {
printf("%*d", 5, v);
}
CUTE_HOST_DEVICE void
pretty_print(int64_t const& v) {
printf("%*lld", 5, static_cast<long long>(v));
}
CUTE_HOST_DEVICE void
pretty_print(uint64_t const& v) {
printf("%*llu", 5, static_cast<unsigned long long>(v));
}
CUTE_HOST_DEVICE void
pretty_print(half_t const& v) {
printf("%*.2f", 8, float(v));
}
CUTE_HOST_DEVICE void
pretty_print(float const& v) {
printf("%*.2e", 10, v);
}
CUTE_HOST_DEVICE void
pretty_print(double const& v) {
printf("%*.3e", 11, v);
}
} // end namespace cute
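To see the per-element dispatch in effect, a hedged host-side sketch (tensor shape and contents are arbitrary):

```cpp
// Hedged sketch: print_tensor now formats each element through pretty_print.
#include <cute/tensor.hpp>

void print_example() {
  using namespace cute;
  auto t = make_tensor<float>(make_shape(Int<2>{}, Int<3>{}));  // owning 2x3 tensor
  clear(t);                                                     // zero-fill
  print_tensor(t);  // each element rendered via pretty_print(float) -> "%*.2e"
}
```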

View File

@ -122,6 +122,9 @@ using CUTE_STL_NAMESPACE::is_empty_v;
using CUTE_STL_NAMESPACE::invoke_result_t;
using CUTE_STL_NAMESPACE::common_type;
using CUTE_STL_NAMESPACE::common_type_t;
// <utility>
using CUTE_STL_NAMESPACE::declval;

View File

@ -47,6 +47,18 @@ namespace cutlass {
namespace arch {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
// This enum class specifies the NamedBarriers reserved by CUTLASS.
enum class ReservedNamedBarriers {
EpilogueBarrier = 0,
TransposeBarrier = 1,
TransformBarrier = 2,
StreamkBarrier0 = 3,
StreamkBarrier1 = 4
, FirstUserBarrier = StreamkBarrier1 + 1
};
class NamedBarrier {
// Data Members:
@ -60,9 +72,19 @@ class NamedBarrier {
public:
// Constructor for CUTLASS developers:
// effective barrier ID starts from 0
CUTLASS_DEVICE
NamedBarrier(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers)
: num_threads_(num_threads), id_(static_cast<uint32_t>(reserved_named_barriers)) {}
// Constructor for CUTLASS users:
// effective barrier ID starts from ReservedNamedBarrierCount
CUTLASS_DEVICE
NamedBarrier(uint32_t num_threads, uint32_t id = 0)
: num_threads_(num_threads), id_(id) {}
: num_threads_(num_threads), id_(id + ReservedNamedBarrierCount) {
CUTLASS_ASSERT(id + ReservedNamedBarrierCount <= HardwareMaxNumNamedBarriers && "Effective barrier_id should not exceed 16.");
}
CUTLASS_DEVICE
void arrive_and_wait() const {
@ -80,8 +102,52 @@ class NamedBarrier {
}
// Static variants
// Calling interface for CUTLASS users:
// effective barrier ID starts from ReservedNamedBarrierCount
CUTLASS_DEVICE
static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) {
arrive_and_wait_internal(num_threads, barrier_id + ReservedNamedBarrierCount);
}
// Calling interface for CUTLASS developers:
// effective barrier ID starts from 0
CUTLASS_DEVICE
static void arrive_and_wait(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) {
arrive_and_wait_internal(num_threads, static_cast<int>(reserved_named_barriers));
}
// Calling interface for CUTLASS users:
// effective barrier ID starts from ReservedNamedBarrierCount
CUTLASS_DEVICE
static void arrive(uint32_t num_threads, uint32_t barrier_id) {
arrive_internal(num_threads, barrier_id + ReservedNamedBarrierCount);
}
// Calling interface for CUTLASS developers:
// effective barrier ID starts from 0
CUTLASS_DEVICE
static void arrive(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) {
arrive_internal(num_threads, static_cast<int>(reserved_named_barriers));
}
// Calling interface for CUTLASS users:
// effective barrier ID starts from ReservedNamedBarrierCount
CUTLASS_DEVICE
static void sync(uint32_t num_threads, uint32_t barrier_id) {
sync_internal(num_threads, barrier_id + ReservedNamedBarrierCount);
}
// Calling interface for CUTLASS developers:
// effective barrier ID starts from 0
CUTLASS_DEVICE
static void sync(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) {
sync_internal(num_threads, static_cast<int>(reserved_named_barriers));
}
private:
CUTLASS_DEVICE
static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) {
#if CUDA_BARRIER_ENABLED
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
#elif defined(__CUDA_ARCH__)
@ -90,7 +156,7 @@ class NamedBarrier {
}
CUTLASS_DEVICE
static void arrive(uint32_t num_threads, uint32_t barrier_id) {
static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) {
#if CUDA_BARRIER_ENABLED
asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
#elif defined(__CUDA_ARCH__)
@ -99,9 +165,16 @@ class NamedBarrier {
}
CUTLASS_DEVICE
static void sync(uint32_t num_threads, uint32_t barrier_id) {
NamedBarrier::arrive_and_wait(num_threads, barrier_id);
static void sync_internal(uint32_t num_threads, uint32_t barrier_id) {
NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id);
}
public:
// Currently we reserve ReservedNamedBarrierCount NamedBarriers for CUTLASS' own use cases,
// while leaving the remaining ones for general users.
static const uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(ReservedNamedBarriers::FirstUserBarrier);
static const uint32_t HardwareMaxNumNamedBarriers = 16;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
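A hedged device-side sketch of the two calling conventions introduced here (thread count is illustrative): user-facing integer IDs are shifted past the reserved range, while CUTLASS-internal call sites name their reserved slot.

```cpp
// Hedged sketch of the two NamedBarrier calling conventions.
__device__ void named_barrier_example() {
  // User-facing ID 0 maps to hardware barrier 0 + ReservedNamedBarrierCount.
  cutlass::arch::NamedBarrier::sync(/*num_threads=*/128, /*barrier_id=*/0u);

  // Internal call sites select a reserved slot by name (hardware barrier 0 here).
  cutlass::arch::NamedBarrier::sync(
      /*num_threads=*/128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
```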

View File

@ -80,7 +80,7 @@ public:
static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits<T>::value;
/// Number of storage elements
static size_t const kStorageElements = N / kElementsPerStoredItem;
static size_t const kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem;
/// Number of logical elements
static size_t const kElements = N;
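A hedged arithmetic check of the ceil-divide fix, assuming a 32-bit Storage word holding 4-bit elements (8 elements per word): an array of 10 such elements now reserves 2 words rather than truncating to 1.

```cpp
// Hedged arithmetic check (assumes 8 elements per storage word).
static_assert((10 + 8 - 1) / 8 == 2, "10 sub-byte elements need 2 storage words");
static_assert(10 / 8 == 1, "the previous truncating divide under-allocated");
```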

View File

@ -68,7 +68,7 @@ template <
struct NamedBarrierSync {
CUTLASS_DEVICE
static void sync() {
cutlass::arch::NamedBarrier::sync(ThreadCount, BarrierId);
cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast<arch::ReservedNamedBarriers>(BarrierId));
}
};
@ -227,9 +227,9 @@ template <
uint32_t MaxNumNamedBarriers = 16
>
struct NamedBarrierManager {
static constexpr uint32_t HardwareMaxNumNamedBarriers = 16;
static_assert(MaxNumNamedBarriers <= HardwareMaxNumNamedBarriers);
static_assert(MaxNumNamedBarriers + Offset <= HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15");
static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers);
static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15");
// Number of threads participating in the barrier
static constexpr uint32_t ThreadCount = ThreadCount_;

View File

@ -55,6 +55,7 @@
#include <cstring>
#endif
#include <cuda_bf16.h>
#include "cutlass/cutlass.h"
namespace cutlass {
@ -83,6 +84,28 @@ struct alignas(2) bfloat16_t {
return h;
}
private:
struct from_32_bit_integer_t {};
static constexpr from_32_bit_integer_t from_32_bit_integer{};
template<class T>
CUTLASS_HOST_DEVICE
explicit bfloat16_t(from_32_bit_integer_t, T x) {
static_assert(cutlass::platform::is_integral<T>::value && sizeof(T) == 4, "Requires 32-bit integer");
float flt = static_cast<float>(x);
uint32_t bits;
#if defined(__CUDA_ARCH__)
bits = reinterpret_cast<uint32_t &>(flt);
#else
std::memcpy(&bits, &flt, sizeof(bits));
#endif
storage = uint16_t(bits >> 16);
}
public:
/// Default constructor
bfloat16_t() = default;
@ -129,18 +152,10 @@ struct alignas(2) bfloat16_t {
/// Integer conversion - round toward nearest
CUTLASS_HOST_DEVICE
explicit bfloat16_t(int x) {
float flt = static_cast<float>(x);
uint32_t bits;
explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {}
#if defined(__CUDA_ARCH__)
bits = reinterpret_cast<uint32_t &>(flt);
#else
std::memcpy(&bits, &flt, sizeof(bits));
#endif
storage = uint16_t(bits >> 16);
}
CUTLASS_HOST_DEVICE
explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {}
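A hedged spot-check of the refactor: the existing int conversion and the newly added uint32_t conversion both delegate to the shared 32-bit integer path.

```cpp
// Hedged spot-check: both 32-bit integer constructors take the shared path.
void bf16_example() {
  cutlass::bfloat16_t a(-3);           // int overload, now via from_32_bit_integer
  cutlass::bfloat16_t b(uint32_t{3});  // newly added uint32_t overload
  (void)a; (void)b;
}
```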
/// Converts to float
CUTLASS_HOST_DEVICE

View File

@ -35,7 +35,6 @@
#pragma once
#include <cstdio>
#include <cuda_runtime_api.h>
#include "cutlass/cutlass.h"
#include "cutlass/trace.h"

View File

@ -28,6 +28,16 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain
existing integrations of CUTLASS require C++11 host compilers.
Until this requirement can be lifted, certain headers with this annotation are required
to remain consistent with C++11 syntax.
C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`.
*/
#pragma once
#include <cuComplex.h>

View File

@ -0,0 +1,147 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Interface between a CUTLASS device-wide operator and CUDA.
*/
#pragma once
#include <cuda_runtime_api.h>
#include "cutlass/cutlass.h"
#include "cutlass/trace.h"
#include "cutlass/platform/platform.h"
#if ! defined(__CUDACC_RTC__)
#include <cstdio>
#endif
#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))
# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Macro-level guard for CUDA Host Adapter
//
#if !defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER)
#define CUTLASS_ENABLE_CUDA_HOST_ADAPTER false
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and
/// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API.
struct CudaHostAdapter {
/// Limit the number of kernels
static constexpr int32_t kMaximumKernelCount = 4;
/// Maximum cluster size
static constexpr int MaxClusterSize = 32;
//
// Data members
//
/// Handles
void *kernel_handles[kMaximumKernelCount];
int32_t kernel_count = 0;
CudaHostAdapter() = default;
/// Dtor
virtual ~CudaHostAdapter() {}
/// Copy Ctor deleted
CudaHostAdapter(const CudaHostAdapter&) = delete;
/// Copy Assignment deleted
CudaHostAdapter& operator=(const CudaHostAdapter&) = delete;
/// Move ctor deleted
CudaHostAdapter(CudaHostAdapter&&) = delete;
/// Move assignment deleted
CudaHostAdapter& operator=(CudaHostAdapter&&) = delete;
/// Ctor
inline CudaHostAdapter(
void **kernel_handles_,
int32_t kernel_count_
):
kernel_count(kernel_count_)
{
CUTLASS_ASSERT(kernel_count >= 0);
for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) {
kernel_handles[i] = kernel_handles_[i];
}
}
/// Queries the occupancy of a kernel
virtual Status query_occupancy(
int32_t *device_sms,
int32_t *sm_occupancy,
int32_t kernel_index,
int32_t thread_count,
int32_t smem_size) = 0;
/// Launches a kernel without using Threadblock Clusters.
virtual Status launch(
dim3 const grid_dims,
dim3 const block_dims,
size_t const smem_size,
cudaStream_t cuda_stream,
void** kernel_params,
int32_t kernel_index) = 0;
/// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters.
virtual Status launch(
dim3 const grid_dims,
dim3 const cluster_dims,
dim3 const block_dims,
size_t const smem_size,
cudaStream_t cuda_stream,
void** kernel_params,
int32_t kernel_index) = 0;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
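A hedged sketch (not part of this commit) of how the pure-virtual interface above might be satisfied with the CUDA Runtime API; occupancy and cluster launch are stubbed and error handling is elided:

```cpp
// Hedged sketch: a minimal CudaHostAdapter backed by the CUDA Runtime API.
struct RuntimeHostAdapter : cutlass::CudaHostAdapter {
  using cutlass::CudaHostAdapter::CudaHostAdapter;

  cutlass::Status query_occupancy(int32_t* device_sms, int32_t* sm_occupancy,
                                  int32_t kernel_index, int32_t thread_count,
                                  int32_t smem_size) override {
    // A real adapter would call cudaOccupancyMaxActiveBlocksPerMultiprocessor here.
    *device_sms = 0; *sm_occupancy = 0;
    return cutlass::Status::kSuccess;
  }

  cutlass::Status launch(dim3 const grid, dim3 const block, size_t const smem,
                         cudaStream_t stream, void** params, int32_t kernel_index) override {
    cudaError_t err = cudaLaunchKernel(kernel_handles[kernel_index], grid, block,
                                       params, smem, stream);
    return err == cudaSuccess ? cutlass::Status::kSuccess : cutlass::Status::kErrorInternal;
  }

  cutlass::Status launch(dim3 const grid, dim3 const cluster, dim3 const block,
                         size_t const smem, cudaStream_t stream,
                         void** params, int32_t kernel_index) override {
    // A cluster-aware adapter would use the extensible launch API here.
    return cutlass::Status::kErrorNotSupported;
  }
};
```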

View File

@ -0,0 +1,64 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/container/tuple.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <size_t I, class Tuple>
struct deduce_mixed_width_dtype {
static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represent Operand, Scale, and Bias, respectively.");
private:
using underlying_tuple = cute::conditional_t<cute::is_tuple<Tuple>::value, Tuple, cute::tuple<Tuple>>;
static constexpr size_t valid_index = cute::min(I, cute::tuple_size_v<underlying_tuple> - 1);
public:
using type = cute::conditional_t<(I < cute::tuple_size_v<underlying_tuple>),
cute::tuple_element_t<valid_index, underlying_tuple>,
void>;
};
template <size_t I, class Tuple>
using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype<I, Tuple>::type;
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
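A hedged illustration of how the helper resolves the optional entries (the element types are just examples):

```cpp
// Hedged illustration: resolving operand / scale / bias types from a tuple.
using MixedInput = cute::tuple<cutlass::int4b_t, cutlass::half_t>;
using OperandT   = cutlass::gemm::collective::detail::deduce_mixed_width_dtype_t<0, MixedInput>; // int4b_t
using ScaleT     = cutlass::gemm::collective::detail::deduce_mixed_width_dtype_t<1, MixedInput>; // half_t
using BiasT      = cutlass::gemm::collective::detail::deduce_mixed_width_dtype_t<2, MixedInput>; // void (no bias given)
static_assert(cute::is_same_v<BiasT, void>, "an index past the tuple yields void");
```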

View File

@ -187,6 +187,7 @@ constexpr bool is_tma_copy_engine() {
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_IM2COL, GmemTiledCopy>
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_IM2COL_MULTICAST, GmemTiledCopy>
|| cute::is_base_of_v<cute::SM90_TMA_STORE, GmemTiledCopy>
|| cute::is_base_of_v<cute::SM90_TMA_STORE_IM2COL, GmemTiledCopy>
) {
return true;
}

View File

@ -104,19 +104,22 @@ sm90_compute_tile_shape_or_override() {
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
if constexpr (detail::sm90_is_cooperative_v<Schedule>) {
using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{})));
if constexpr (size<0>(TileShape_MNK{}) >= 128) {
return Shape<_128,_32>{};
return Shape<_128, N_tile>{};
}
else {
return Shape<_64,_32>{};
return Shape<_64, N_tile>{};
}
}
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
if constexpr (sizeof_bits_v<ElementD> == 8) {
return Shape<_64,_64>{};
using N_tile = decltype(cute::min(_64{}, get<1>(TileShape_MNK{})));
return Shape<_64, N_tile>{};
}
else {
return Shape<_64,_32>{};
using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{})));
return Shape<_64,N_tile>{};
}
}
else {
@ -265,6 +268,13 @@ struct Sm90TmaBuilderImpl {
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
using CopyOpS2G =
SM90_TMA_STORE
;
using CopyOpG2S =
SM90_TMA_LOAD
;
// TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
// instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
using FusionCallbacks =
@ -285,10 +295,10 @@ struct Sm90TmaBuilderImpl {
ElementD,
GmemStrideTypeD,
FusionCallbacks,
SM90_TMA_LOAD,
CopyOpG2S,
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, EpilogueTile_MN>()),
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
SM90_TMA_STORE,
CopyOpS2G,
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>())
>;
@ -400,6 +410,7 @@ template <
class ElementD,
class GmemLayoutTagD,
int AlignmentD,
class Schedule,
FloatRoundStyle RoundStyle
>
struct CollectiveBuilder<
@ -416,9 +427,11 @@ struct CollectiveBuilder<
ElementD,
GmemLayoutTagD,
AlignmentD,
NoSmemWarpSpecialized,
fusion::LinearCombination<ElementD,ElementCompute,ElementCompute,RoundStyle>,
void> {
Schedule,
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
cute::enable_if_t<cute::is_same_v<Schedule, NoSmemWarpSpecialized> ||
cute::is_same_v<Schedule, NoSmemWarpSpecializedArray> ||
cute::is_same_v<Schedule, NoSmemWarpSpecializedGroup> >> {
// Passing void C disables source load
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
@ -433,12 +446,21 @@ struct CollectiveBuilder<
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
ScaleType, RoundStyle, ElementC>;
using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::DefaultEpilogue<
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
ThreadOp,
cutlass::gemm::EpilogueDefault>
using CollectiveOp = cute::conditional_t<
cute::is_same_v<Schedule, NoSmemWarpSpecialized>,
cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::DefaultEpilogue<
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
ThreadOp,
cutlass::gemm::EpilogueDefault>>,
// Epilogue for Ptr-Array and Grouped Gemm
cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::DefaultEpilogueArray<
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
ThreadOp,
Schedule>>
>;
};
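A hedged sketch (not part of the commit) of how a caller might now reach the Ptr-Array/Grouped epilogue simply by passing one of the new schedule tags to the builder; every concrete type below (tile shape, element types, layouts, alignments) is a placeholder chosen for illustration.

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"

using EpilogueGrouped = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cute::Shape<cute::_128, cute::_128, cute::_64>,   // TileShape_MNK
    cute::Shape<cute::_1, cute::_1, cute::_1>,        // ClusterShape_MNK
    cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,                                     // ElementAccumulator, ElementCompute
    float, cutlass::layout::RowMajor, 4,              // ElementC, GmemLayoutTagC, AlignmentC
    float, cutlass::layout::RowMajor, 4,              // ElementD, GmemLayoutTagD, AlignmentD
    cutlass::epilogue::NoSmemWarpSpecializedGroup     // selects the DefaultEpilogueArray branch above
  >::CollectiveOp;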
@ -533,6 +555,9 @@ struct CollectiveBuilder<
FusionOperation,
void> {
private:
static_assert(cute::is_same_v<FusionOperation, fusion::LinearCombination<ElementD,ElementCompute,ElementC,ElementCompute>>,
"Auto schedule doesn't support fusion. Use one of the TmaWarpSpecialized schedules instead.");
// Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance)
// since TMA epilogues are not compatible with non-TMA non-WS mainloops
using EpilogueSchedule = NoSmemWarpSpecialized;
@ -595,7 +620,7 @@ CollectiveBuilder<
cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
private:
using FusionOp =
fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementC, ElementCompute, Schedule::Round>;
using ImplSchedule =
cute::conditional_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule>,
TmaWarpSpecialized, TmaWarpSpecializedCooperative>;
@ -677,7 +702,7 @@ private:
GmemStrideTypeAux, typename Schedule::ElementT>());
using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute
typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute
>;
using FusionCallbacksAux = fusion::FusionCallbacks<
DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux
@ -685,7 +710,7 @@ private:
using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
Schedule::template ActivationFunctor, ElementD, ElementCompute,
typename Schedule::ElementBias, ElementCompute
typename Schedule::ElementBias, ElementC_, ElementCompute
>;
using FusionCallbacksNoAux = fusion::FusionCallbacks<
DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN
@ -750,7 +775,7 @@ struct CollectiveBuilder<
GmemLayoutTagD,
AlignmentD,
cutlass::gemm::EpilogueTransposed,
fusion::LinearCombination<ElementD,ElementCompute,ElementCompute,RoundStyle>,
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
void> {
// Passing void C disables source load
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,

View File

@ -62,7 +62,7 @@ template <
class GmemLayoutTagD,
int AlignmentD,
class Schedule,
class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute>,
class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute,ElementC,ElementCompute>,
class Enable = void
>
struct CollectiveBuilder {

View File

@ -54,6 +54,7 @@ class CollectiveEpilogue {
#include "detail.hpp"
#include "default_epilogue.hpp"
#include "default_epilogue_array.hpp"
#include "epilogue_tensor_broadcast.hpp"
#include "sm70_epilogue_vectorized.hpp"
#include "sm90_epilogue_tma_warpspecialized.hpp"

View File

@ -131,8 +131,9 @@ public:
return true;
}
// Note: SharedStorage is unused for DefaultEpilogue
CUTLASS_HOST_DEVICE
DefaultEpilogue(Params const& params_)
DefaultEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage())
: params(params_), epilogue_op(params_.thread) { }
CUTLASS_DEVICE

View File

@ -0,0 +1,254 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cute/tensor.hpp"
#include "cute/numeric/int.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Applies an elementwise operation to all elements within the fragment
// and writes them out to destination storage.
template <
class StrideC_,
class StrideD_,
class ThreadEpilogueOp_,
class EpilogueSchedule_
>
class DefaultEpilogueArray {
public:
//
// Type Aliases
//
using EpilogueSchedule = EpilogueSchedule_;
// derived types of output thread level operator
using ThreadEpilogueOp = ThreadEpilogueOp_;
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementScalar = ElementCompute;
using ElementC = typename ThreadEpilogueOp::ElementC;
using StrideC = StrideC_;
using ElementD = typename ThreadEpilogueOp::ElementD;
using StrideD = StrideD_;
using StridesC = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
StrideC const*, StrideC>;
using StridesD = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
StrideD const*, StrideD>;
using GmemTiledCopyC = void;
using GmemTiledCopyD = void;
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup> ||
cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedArray>, "Incompatible epilogue schedule.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
struct SharedStorage { };
// Host side epilogue arguments
struct Arguments {
typename ThreadEpilogueOp::Params thread{};
ElementC const** ptr_C = nullptr;
StridesC dC{};
ElementD** ptr_D = nullptr;
StridesD dD{};
};
// Device side epilogue params
using Params = Arguments;
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(
ProblemShape const&,
Arguments const& args,
[[maybe_unused]] void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) {
return cutlass::Status::kSuccess;
}
template<class ProblemShape>
CUTLASS_HOST_DEVICE static bool
can_implement(
[[maybe_unused]] ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
return true;
}
CUTLASS_HOST_DEVICE
DefaultEpilogueArray(Params const& params_)
: params(params_), epilogue_op(params_.thread) { }
CUTLASS_DEVICE
bool
is_source_needed() {
return epilogue_op.is_source_needed();
}
template<
class ProblemShapeMNKL,
class BlockShapeMNK,
class BlockCoordMNKL,
class FrgEngine, class FrgLayout,
class TiledMma,
class ResidueMNK
>
CUTLASS_HOST_DEVICE void
operator()(
ProblemShapeMNKL problem_shape_mnkl,
BlockShapeMNK blk_shape_MNK,
BlockCoordMNKL blk_coord_mnkl,
cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
TiledMma tiled_mma,
ResidueMNK residue_mnk,
int thread_idx,
[[maybe_unused]] char* smem_buf)
{
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto L = get<3>(problem_shape_mnkl);
// Batches are managed by using appropriate pointers to C and D matrices
const int32_t mock_L = 1;
const int32_t mock_l_coord = 0;
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
StrideC stride_c;
StrideD stride_d;
if constexpr (cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>) {
stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD[l_coord]);
}
else {
stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);
}
// Represent the full output tensor
Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C[l_coord]), make_shape(M,N,mock_L), stride_c); // (m,n,l)
Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l)
Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N)
// Partition source and destination tiles to match the accumulator partitioning
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N)
Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N)
static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
"Source and destination must have the same number of elements.");
CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
"Accumulator count must have the same destination element count.");
// Make an identity coordinate tensor for predicating our output MN tile
auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
Tensor tCcD = thr_mma.partition_C(cD);
// source is needed
if (epilogue_op.is_source_needed()) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators); ++i) {
if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
}
}
}
// source is not needed, avoid load
else {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators); ++i) {
if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
tCgD(i) = epilogue_op(accumulators(i));
}
}
}
}
private:
Params params;
ThreadEpilogueOp epilogue_op;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace collective
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
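As a usage note, a sketch under stated assumptions (not part of the commit) of how the host-side Arguments of DefaultEpilogueArray are shaped for the grouped-GEMM schedule; the element and stride types are placeholders.

#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/default_epilogue_array.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"

using ElementC = float;
using ElementD = float;
using StrideC  = cute::Stride<int64_t, cute::Int<1>, int64_t>;  // rank-3: [M, N, L]
using StrideD  = StrideC;
using ThreadOp = cutlass::epilogue::thread::LinearCombination<ElementD, 1, float, float>;

using GroupedEpilogue = cutlass::epilogue::collective::DefaultEpilogueArray<
    StrideC, StrideD, ThreadOp, cutlass::epilogue::NoSmemWarpSpecializedGroup>;

// For NoSmemWarpSpecializedGroup the strides are per-group arrays; for
// NoSmemWarpSpecializedArray they collapse to a single stride shared by all
// batches. The C/D pointers are per-group (or per-batch) arrays in both cases.
ElementC const** ptr_C = nullptr;  // device array: one C pointer per group
ElementD**       ptr_D = nullptr;  // device array: one D pointer per group
StrideC const*   dC    = nullptr;  // device array: one C stride per group
StrideD const*   dD    = nullptr;  // device array: one D stride per group

typename GroupedEpilogue::Arguments args{{/*alpha, beta*/}, ptr_C, dC, ptr_D, dD};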

View File

@ -170,7 +170,8 @@ public:
[[maybe_unused]] TileCoordMNKL tile_coord_mnkl,
[[maybe_unused]] TiledMma tiled_mma,
[[maybe_unused]] int thread_idx,
[[maybe_unused]] TensorStorage& shared_tensors)
[[maybe_unused]] TensorStorage& shared_tensors,
[[maybe_unused]] int subtile_idx=-1)
{
return load_pipe_producer_state;
}
@ -202,7 +203,8 @@ public:
cute::Tensor<AccEngine,AccLayout> accumulators,
TiledMma tiled_mma,
int thread_idx,
TensorStorage& shared_tensors)
TensorStorage& shared_tensors,
int subtile_idx = -1)
{
constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK);
auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {

View File

@ -85,6 +85,7 @@ public:
using CopyAtomR2G = CopyAtomR2G_;
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
@ -179,7 +180,7 @@ public:
// synchronizing function for smem reads/writes
#if CUDA_BARRIER_ENABLED
auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); };
auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
#else
auto synchronize = [] () { __syncthreads(); };
#endif
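A minimal device-side sketch of the barrier pattern above; the two-argument sync overload is taken from the call sites in this diff, while the helper function itself is illustrative.

#include "cutlass/arch/barrier.h"

// All threads participating in the epilogue call sync with the same reserved
// barrier id and the same thread count; user-managed barriers should avoid the
// reserved ids.
__device__ void epilogue_smem_sync(int num_threads) {
  cutlass::arch::NamedBarrier::sync(
      num_threads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}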

View File

@ -109,8 +109,8 @@ public:
using CopyOpR2S = CopyOpR2S_;
using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
using GmemTiledCopyC = SM90_TMA_LOAD;
using GmemTiledCopyD = SM90_TMA_STORE;
using GmemTiledCopyC = CopyOpG2S;
using GmemTiledCopyD = CopyOpS2G;
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
@ -198,12 +198,12 @@ public:
struct Params {
using TMA_C = decltype(make_tma_copy(
CopyOpG2S{},
make_tensor(static_cast<SmemElementC const*>(nullptr),
make_tensor(make_gmem_ptr(static_cast<SmemElementC const*>(nullptr)),
repeat_like(StrideC{}, int32_t(0)), StrideC{}),
SmemLayoutC{}(_,_,0)));
using TMA_D = decltype(make_tma_copy(
CopyOpS2G{},
make_tensor(static_cast<ElementD const*>(nullptr),
make_tensor(make_gmem_ptr(static_cast<ElementD const*>(nullptr)),
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
SmemLayoutD{}(_,_,0)));
@ -225,14 +225,20 @@ public:
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto M_C = size(M);
auto M_D = size(M);
typename Params::TMA_C tma_load_c;
if constexpr (not cute::is_void_v<ElementC>) {
Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC));
Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC));
tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0));
}
Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD));
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
typename Params::TMA_D tma_store_d = make_tma_copy(
CopyOpS2G{},
tensor_d,
@ -332,25 +338,31 @@ public:
TileCoordMNKL tile_coord_mnkl,
TiledMma tiled_mma,
int thread_idx,
TensorStorage& shared_tensors) {
TensorStorage& shared_tensors,
int subtile_idx=-1) {
using namespace cute;
// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
auto coord_shape = make_coord(m_coord, n_coord, l_coord);
// Tile residue
auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(tile_shape_MNK)>{}, [&](auto i) {
auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(tile_shape_MNK)>{}, [&](auto i) {
return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
}));
auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(tile_shape_MNK)>{}, [&](auto i) {
auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(tile_shape_MNK)>{}, [&](auto i) {
return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
}));
auto residue_mn = make_coord(m_max_coord, n_max_coord);
// Represent the full source tensor, slice to get the tile this CTA is currently responsible for
Tensor mC = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N)
Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{}));
Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N)
// Apply epilogue subtile, get matching smem tensor
SmemElementC* ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D.data());
@ -391,6 +403,9 @@ public:
for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) {
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) {
if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gC_epi)) + epi_m) != subtile_idx) {
continue;
}
// Acquire the lock for this stage
constexpr uint16_t mcast_mask = 0;
uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
@ -449,31 +464,36 @@ public:
cute::Tensor<AccEngine,AccLayout> accumulators,
TiledMma tiled_mma,
int thread_idx,
TensorStorage& shared_tensors) {
TensorStorage& shared_tensors,
int subtile_idx=-1) {
using namespace cute;
using ElementAccumulator = typename AccEngine::value_type;
using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::ElementCompute;
using ElementCompute = cute::conditional_t<cute::is_void_v<ElementCompute_>,ElementAccumulator,ElementCompute_>;
static_assert(is_rmem<AccEngine>::value, "Accumulator must be RF resident.");
static_assert(cute::rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<TileShapeMNK>::value, "TileShapeMNK must be static");
static_assert(cute::rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
static_assert(cute::rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{});
auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{});
auto mma_tile_m = tile_size<0>(tiled_mma);
auto mma_tile_n = tile_size<1>(tiled_mma);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
auto coord_shape = make_coord(m_coord, n_coord, l_coord);
// Represent the full output tensor, slice to get the tile this CTA is responsible for
Tensor mD = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N)
Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{}));
Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N)
// Apply epilogue subtiling
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
@ -530,11 +550,11 @@ public:
Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
// Coordinate tensors and residue for tile quantization
auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(CtaTileMNK{})>{}, [&](auto i) {
auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(CtaTileMNK{})>{}, [&](auto i) {
auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl);
return cute::max(0, c_m);
}));
auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(CtaTileMNK{})>{}, [&](auto i) {
auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(CtaTileMNK{})>{}, [&](auto i) {
auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl);
return cute::max(0, c_n);
}));
@ -559,13 +579,13 @@ public:
tRS_cD,
tRS_rC
};
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks<RefSrc>(cst_args);
bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed();
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
// Thread synchronizer for previously issued waits or fences
// to ensure visibility of smem reads/writes to threads or TMA unit
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 0); };
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
// Predication for TMA store (one warp issues TMA store)
bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0;
@ -594,9 +614,12 @@ public:
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) {
bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1;
if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gD_epi)) + epi_m) != subtile_idx) {
continue;
}
// The current tile in accumulator
int mma_m = epi_m;
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n;
Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n);
if (is_producer_load_needed) {

View File

@ -46,6 +46,8 @@ namespace cutlass::epilogue {
//////////////////////////////////////////////////////////////////////////////
struct NoSmemWarpSpecialized {};
struct NoSmemWarpSpecializedArray {};
struct NoSmemWarpSpecializedGroup {};
struct TmaWarpSpecialized {};
struct TmaWarpSpecializedCooperative {};
// DEPRECATED schedules, will be removed in next release

View File

@ -50,6 +50,8 @@ struct FusionOperation {
// metadata types/queries that can be overridden
using ElementOutput = void;
using ElementCompute = void;
using ElementSource = void;
static constexpr bool IsSourceSupported = false;
using ElementScalar = void;
@ -96,11 +98,13 @@ struct ScaledAcc : FusionOperation {
template<
class ElementOutput_,
class ElementCompute_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinearCombination
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
using ElementSource = ElementSource_;
static constexpr bool IsSourceSupported = true;
};
@ -109,11 +113,12 @@ template<
template <class> class ActivationFn_,
class ElementOutput_,
class ElementCompute_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinCombEltAct
: LinearCombination<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
using ActivationFn = ActivationFn_<ElementCompute_>;
static constexpr bool IsEltActSupported = true;
};
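For downstream users, a small sketch (an assumption-laden illustration, not from the commit) of the template argument order once ElementSource is inserted; the concrete element types are placeholders.

#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/fusion/operations.hpp"

// Order: <ActivationFn, ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>
using EltActOp = cutlass::epilogue::fusion::LinCombEltAct<
    cutlass::epilogue::thread::ReLu,  // ActivationFn
    cutlass::half_t,                  // ElementOutput (D)
    float,                            // ElementCompute
    cutlass::half_t,                  // ElementSource (C), the newly inserted slot
    float>;                           // ElementScalar (alpha/beta)

static_assert(EltActOp::IsSourceSupported && EltActOp::IsEltActSupported);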
@ -123,12 +128,13 @@ template<
class ElementOutput_,
class ElementCompute_,
class ElementBias_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinCombPerRowBias
: LinearCombination<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
using ElementBias = ElementBias_;
static constexpr int AlignmentBias = AlignmentBias_;
static constexpr bool IsPerRowBiasSupported = true;
@ -140,13 +146,14 @@ template<
class ElementOutput_,
class ElementCompute_,
class ElementBias_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinCombPerRowBiasEltAct
: LinCombPerRowBias<ElementOutput_, ElementCompute_,
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
using ActivationFn = ActivationFn_<ElementCompute_>;
static constexpr bool IsEltActSupported = true;
};
@ -160,6 +167,7 @@ template<
class ElementCompute_,
class ElementAux_ = ElementOutput_,
class ElementBias_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
@ -167,7 +175,7 @@ template<
>
struct LinCombPerRowBiasEltActAux
: LinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
using ElementAux = ElementAux_;
using GmemLayoutTagAux = GmemLayoutTagAux_;
static constexpr int AlignmentAux = AlignmentAux_;
@ -180,6 +188,7 @@ template<
class ElementOutput_,
class ElementCompute_,
class ElementBias_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_, // per-row alpha/beta
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
int AlignmentScalar_ = 128 / sizeof_bits_v<ElementScalar_>,
@ -187,7 +196,7 @@ template<
>
struct PerRowLinCombPerRowBiasEltAct
: LinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
static constexpr int AlignmentScalar = AlignmentScalar_;
static constexpr bool IsPerRowScaleSupported = true;
};
@ -202,13 +211,14 @@ template<
class ElementOutput_,
class ElementCompute_,
class ElementBias_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct ScaledLinCombPerRowBiasEltAct
: LinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
static constexpr bool IsScaleFactorSupported = true;
};
@ -231,6 +241,7 @@ template<
class ElementAux_ = ElementOutput_,
class ElementAmax_ = ElementCompute_,
class ElementBias_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
@ -238,7 +249,7 @@ template<
>
struct ScaledLinCombPerRowBiasEltActAmaxAux
: ScaledLinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
using ElementAmax = ElementAmax_;
static constexpr bool IsAbsMaxSupported = true;
@ -257,12 +268,16 @@ template<
class ElementOutput_,
class ElementCompute_,
class ElementAux_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinCombDeEltAct
: LinCombEltAct<ActivationFn_, ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
using ActivationFn = ActivationFn_<ElementCompute_>;
static constexpr bool IsDeEltActSupported = true;
using ElementAux = ElementAux_;
using GmemLayoutTagAux = GmemLayoutTagAux_;
static constexpr int AlignmentAux = AlignmentAux_;
@ -280,6 +295,7 @@ template<
class ElementCompute_,
class ElementAux_ = ElementOutput_,
class ElementBias_ = ElementCompute_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
@ -287,7 +303,7 @@ template<
>
struct LinCombDeEltActDePerRowBias
: LinCombDeEltAct<GmemLayoutTagAux_, ActivationFn_, ElementOutput_, ElementCompute_,
ElementAux_, ElementScalar_, AlignmentAux_, RoundStyle_> {
ElementAux_, ElementSource_, ElementScalar_, AlignmentAux_, RoundStyle_> {
using ElementBias = ElementBias_;
static constexpr int AlignmentBias = AlignmentBias_;
static constexpr bool IsDePerRowBiasSupported = true;

View File

@ -113,13 +113,14 @@ struct FusionCallbacks<
template<
class ElementOutput,
class ElementCompute,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinearCombination =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
Sm90ScalarBroadcast<ElementScalar>, // alpha
Sm90AccFetch // acc
@ -133,6 +134,7 @@ template <
bool ReuseSmemC,
class ElementOutput,
class ElementCompute,
class ElementSource,
class ElementScalar,
FloatRoundStyle RoundStyle,
class CtaTileShapeMNK,
@ -140,13 +142,13 @@ template <
>
struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::LinearCombination<ElementOutput, ElementCompute, ElementScalar, RoundStyle>,
fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementScalar, RoundStyle> {
> : Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> {
using Impl = Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementScalar, RoundStyle>;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementScalar, RoundStyle>;
using Impl = Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
struct Arguments {
ElementScalar alpha = ElementScalar(1);
@ -180,12 +182,13 @@ template<
template <class> class ActivationFn,
class ElementOutput,
class ElementCompute,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc))
Sm90LinearCombination<ElementCompute, ElementCompute, ElementScalar, RoundStyle> // beta * C + (alpha * acc)
Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle> // beta * C + (alpha * acc)
>;
template <
@ -196,6 +199,7 @@ template <
template <class> class ActivationFn,
class ElementOutput,
class ElementCompute,
class ElementSource,
class ElementScalar,
FloatRoundStyle RoundStyle,
class CtaTileShapeMNK,
@ -203,13 +207,13 @@ template <
>
struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementScalar, RoundStyle>,
fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementScalar, RoundStyle> {
> : Sm90LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle> {
using Impl = Sm90LinCombEltAct<ActivationFn, typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementScalar, RoundStyle>;
using Operation = fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementScalar, RoundStyle>;
using Impl = Sm90LinCombEltAct<ActivationFn, typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
using Operation = fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
struct Arguments {
ElementScalar alpha = ElementScalar(1);
@ -250,6 +254,7 @@ template<
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
@ -257,7 +262,7 @@ template<
using Sm90LinCombPerRowBias =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ScalarBroadcast<ElementScalar>, // alpha
Sm90AccFetch, // acc
@ -273,6 +278,7 @@ template <
class ElementOutput,
class ElementCompute,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentBias,
FloatRoundStyle RoundStyle,
@ -281,15 +287,15 @@ template <
>
struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::LinCombPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>,
fusion::LinCombPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90LinCombPerRowBias<
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> {
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> {
using Impl = Sm90LinCombPerRowBias<
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>;
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
using Operation = fusion::LinCombPerRowBias<
ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>;
ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
struct Arguments {
ElementScalar alpha = ElementScalar(1);
@ -330,13 +336,14 @@ template<
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombPerRowBiasEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>,
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>
>;
template <
@ -348,6 +355,7 @@ template <
class ElementOutput,
class ElementCompute,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentBias,
FloatRoundStyle RoundStyle,
@ -357,21 +365,21 @@ template <
struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::LinCombPerRowBiasEltAct<
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90LinCombPerRowBiasEltAct<
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
> {
using Impl =
Sm90LinCombPerRowBiasEltAct<
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
>;
using Operation =
fusion::LinCombPerRowBiasEltAct<
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
>;
struct Arguments {
@ -426,6 +434,7 @@ template<
class ElementCompute,
class ElementAux = ElementOutput,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
@ -434,7 +443,7 @@ template<
using Sm90LinCombPerRowBiasEltActAux =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>,
Sm90EVT<Sm90AuxStore<Stages, EpilogueTile, ElementAux, RoundStyle, StrideAux, SmemLayoutAtom, CopyOpR2S, AlignmentAux>,
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>
>
>;
@ -449,6 +458,7 @@ template <
class ElementCompute,
class ElementAux,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentAux,
int AlignmentBias,
@ -462,7 +472,7 @@ struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::LinCombPerRowBiasEltActAux<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile,
@ -470,18 +480,18 @@ struct FusionCallbacks<
CopyOpR2S
> : Sm90LinCombPerRowBiasEltActAux<
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpR2S, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
> {
using Impl =
Sm90LinCombPerRowBiasEltActAux<
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpR2S, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>;
using Operation =
fusion::LinCombPerRowBiasEltActAux<
GmemLayoutTagAux, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>;
struct Arguments {
@ -535,6 +545,7 @@ template<
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
int AlignmentScalar = 128 / sizeof_bits_v<ElementScalar>,
@ -543,7 +554,7 @@ template<
using Sm90PerRowLinCombPerRowBias =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // beta
Sm90SrcFetch, // C
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // alpha
Sm90AccFetch, // acc
@ -558,6 +569,7 @@ template<
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
int AlignmentScalar = 128 / sizeof_bits_v<ElementScalar>,
@ -566,7 +578,7 @@ template<
using Sm90PerRowLinCombPerRowBiasEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>,
Sm90PerRowLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute,
ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle>
ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle>
>;
template <
@ -578,6 +590,7 @@ template <
class ElementOutput,
class ElementCompute,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentBias,
int AlignmentScalar,
@ -588,21 +601,21 @@ template <
struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::PerRowLinCombPerRowBiasEltAct<
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90PerRowLinCombPerRowBiasEltAct<
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
> {
using Impl =
Sm90PerRowLinCombPerRowBiasEltAct<
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
>;
using Operation =
fusion::PerRowLinCombPerRowBiasEltAct<
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
>;
struct Arguments {
@ -664,6 +677,7 @@ template<
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
@ -671,7 +685,7 @@ template<
using Sm90ScaledLinCombPerRowBias =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90ScalarBroadcast<ElementScalar, Stride<_0,_0,_0>, 2>, // scale_c * beta
Sm90SrcFetch, // C
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ScalarBroadcast<ElementScalar, Stride<_0,_0,_0>, 3>, // scale_a * scale_b * alpha
Sm90AccFetch, // acc
@ -690,6 +704,7 @@ template<
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
@ -698,7 +713,7 @@ using Sm90ScaledLinCombPerRowBiasEltAct =
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>
>,
Sm90ScalarBroadcast<ElementScalar> // scale_d
>;
@ -712,6 +727,7 @@ template <
class ElementOutput,
class ElementCompute,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentBias,
FloatRoundStyle RoundStyle,
@ -721,21 +737,21 @@ template <
struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::ScaledLinCombPerRowBiasEltAct<
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90ScaledLinCombPerRowBiasEltAct<
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
> {
using Impl =
Sm90ScaledLinCombPerRowBiasEltAct<
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
>;
using Operation =
fusion::ScaledLinCombPerRowBiasEltAct<
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
>;
struct Arguments {
@ -819,6 +835,7 @@ template<
class ElementAux = ElementOutput,
class ElementAmax = ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
@ -827,7 +844,7 @@ template<
using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
Sm90SplitTreeVisitor<
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>,
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
// D = activation(Z) * scale_d, amax_d = max(abs(elements in D))
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_d
@ -860,6 +877,7 @@ template <
class ElementAux,
class ElementAmax,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentAux,
int AlignmentBias,
@ -873,7 +891,7 @@ struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile,
@ -882,19 +900,19 @@ struct FusionCallbacks<
> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux<
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>,
SmemLayoutAtom, CopyOpR2S, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
> {
using Impl =
Sm90ScaledLinCombPerRowBiasEltActAmaxAux<
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>,
SmemLayoutAtom, CopyOpR2S, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>;
using Operation =
fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>;
struct Arguments {
@ -1014,13 +1032,14 @@ template<
class ElementOutput,
class ElementCompute,
class ElementAux = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombDeEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
Sm90LinearCombination<ElementCompute, ElementCompute, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux> // aux
>;
@ -1034,6 +1053,7 @@ template <
class ElementOutput,
class ElementCompute,
class ElementAux,
class ElementSource,
class ElementScalar,
int AlignmentAux,
FloatRoundStyle RoundStyle,
@ -1046,7 +1066,7 @@ struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::LinCombDeEltAct<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementScalar, AlignmentAux, RoundStyle
ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile,
@ -1054,18 +1074,18 @@ struct FusionCallbacks<
CopyOpS2R
> : Sm90LinCombDeEltAct<
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
> {
using Impl =
Sm90LinCombDeEltAct<
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
>;
using Operation =
fusion::LinCombDeEltAct<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementScalar, AlignmentAux, RoundStyle
ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
>;
struct Arguments {
@ -1118,6 +1138,7 @@ template<
class ElementCompute,
class ElementAux = ElementOutput,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
@ -1128,7 +1149,7 @@ using Sm90LinCombDeEltActDePerRowBias =
Sm90EVT<Sm90ColReduction<plus, plus, 0, CtaTileShapeMNK,
ElementBias, ElementCompute, RoundStyle, Stride<_1,_0,int>, AlignmentBias>,
Sm90LinCombDeEltAct<CtaTileShapeMNK, EpilogueTile, Stages, StrideAux, SmemLayoutAtom, CopyOpS2R, ActivationFn,
ElementCompute, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle>
ElementCompute, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle>
>
>;
@ -1143,6 +1164,7 @@ template <
class ElementCompute,
class ElementAux,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentAux,
int AlignmentBias,
@ -1156,7 +1178,7 @@ struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
fusion::LinCombDeEltActDePerRowBias<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile,
@ -1164,18 +1186,18 @@ struct FusionCallbacks<
CopyOpS2R
> : Sm90LinCombDeEltActDePerRowBias<
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
> {
using Impl =
Sm90LinCombDeEltActDePerRowBias<
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>;
using Operation =
fusion::LinCombDeEltActDePerRowBias<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
>;
struct Arguments {

View File

@ -215,16 +215,17 @@ template <
class StrideScalar,
int ScalarCount,
template <class> class ScalarReduceFn,
class ElementSource,
class InputAddOp // Z
>
struct Sm90TreeVisitor<
Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>,
Sm90ScalarBroadcast<ElementScalar, StrideScalar, ScalarCount, ScalarReduceFn>,
Sm90SrcFetch,
Sm90SrcFetch<ElementSource>,
InputAddOp
> : Sm90VisitorImpl<
Sm90ScalarBroadcast<ElementScalar, StrideScalar, ScalarCount, ScalarReduceFn>,
Sm90SrcFetch,
Sm90SrcFetch<ElementSource>,
InputAddOp,
Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>
>
@ -232,11 +233,10 @@ struct Sm90TreeVisitor<
using Impl =
Sm90VisitorImpl<
Sm90ScalarBroadcast<ElementScalar, StrideScalar, ScalarCount, ScalarReduceFn>,
Sm90SrcFetch,
Sm90SrcFetch<ElementSource>,
InputAddOp,
Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>
>;
using Params = typename Impl::Params;
using SharedStorage = typename Impl::SharedStorage;
@ -260,8 +260,9 @@ struct Sm90TreeVisitor<
CUTLASS_DEVICE bool
is_C_load_needed() const {
auto const& bcast_op = get<0>(Impl::ops);
auto const& src_op = get<1>(Impl::ops);
auto const& added_op = get<2>(Impl::ops);
return bcast_op.scalar != 0 || added_op.is_C_load_needed();
return (bcast_op.scalar != 0 && src_op.is_C_load_needed()) || added_op.is_C_load_needed();
}
template <class CallbacksImpl>
@ -320,17 +321,9 @@ struct Sm90TreeVisitor<
// ReLU with aux bit tensor dReLU/dZ
// Aux(i) = Z(i) >= 0 ? 1 : 0
namespace detail {
template <
class ElementOutput,
class ElementCompute,
FloatRoundStyle RoundStyle,
class StrideMNL,
int Alignment,
bool EnableNullptr
>
struct Sm90ReLUAuxStore {
static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet");
// Placeholder node so we can retain standard EVT structure
template <class StrideMNL>
struct Sm90ReLUAuxStore : Sm90VisitorImpl<> {
struct SharedStorage {};
struct Arguments {
@ -362,41 +355,90 @@ struct Sm90ReLUAuxStore {
Sm90ReLUAuxStore() { }
CUTLASS_HOST_DEVICE
Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { }
};
} // namespace detail
Params const params;
// Specialization on the generic compute+aux EVT
template <
// Compute node
template <class> class Activation,
class ElementOutput,
class ElementCompute,
FloatRoundStyle RoundStyle,
// Aux node
int Stages,
class EpilogueTile,
class StrideMNL,
class SmemLayoutAtom,
class CopyOpR2S,
int Alignment,
bool EnableNullptr,
// Input node
class InputOp
>
struct Sm90TreeVisitor<
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle,
enable_if_t<is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::ReLu<ElementCompute>> ||
is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::Clamp<ElementCompute>> >>,
Sm90TreeVisitor<
Sm90AuxStore<
Stages,
EpilogueTile,
cutlass::uint1b_t,
RoundStyle,
StrideMNL,
SmemLayoutAtom,
CopyOpR2S,
Alignment,
EnableNullptr
>,
InputOp
>
> : Sm90VisitorImpl<
Sm90VisitorImpl<
InputOp,
detail::Sm90ReLUAuxStore<StrideMNL>
>,
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
>
{
using Impl =
Sm90VisitorImpl<
Sm90VisitorImpl<
InputOp,
detail::Sm90ReLUAuxStore<StrideMNL>
>,
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
>;
using Params = typename Impl::Params;
using SharedStorage = typename Impl::SharedStorage;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_HOST_DEVICE
Sm90TreeVisitor() {}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_HOST_DEVICE
Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage)
: params(params_), Impl(params_, shared_storage) {}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
Params const& params;
template <class RTensor, class GTensor, class CTensor, class ResidueMN>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
template <class RTensor, class GTensor, class CTensor, class ResidueMN, class CallbacksImpl>
struct ConsumerStoreCallbacks : CallbacksImpl {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
RTensor&& tC_rAux,
GTensor&& tC_gAux,
CTensor tC_cAux,
ResidueMN residue_mn,
Params const& params)
Params const& params,
CallbacksImpl&& impl)
: tC_rAux(cute::forward<RTensor>(tC_rAux)),
tC_gAux(cute::forward<GTensor>(tC_gAux)),
tC_cAux(tC_cAux),
residue_mn(residue_mn),
params(params) {}
params(params),
CallbacksImpl(cute::forward<CallbacksImpl>(impl)) {}
RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
@ -404,13 +446,23 @@ struct Sm90ReLUAuxStore {
ResidueMN residue_mn;
Params const& params;
template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
CUTLASS_DEVICE auto
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
Array<ElementInput, FragmentSize> const& frg_input) {
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<ElementOutput, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
// Unpack callbacks + params
auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple;
auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple;
auto const& [params_input_aux, params_compute] = params;
auto const& [params_input, params_aux] = params_input_aux;
// Visit the input node
Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n);
// Compute activation + aux
using ElementInput = typename decltype(frg_input)::Element;
using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput, FragmentSize, RoundStyle>;
using ConvertAux = PackPredicates<FragmentSize>;
using ComputeOutput = cutlass::epilogue::thread::ReLu<ElementCompute>;
using ComputeOutput = Activation<ElementCompute>;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize, RoundStyle>;
ConvertInput convert_input{};
ComputeOutput relu{};
@ -422,7 +474,12 @@ struct Sm90ReLUAuxStore {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
ElementCompute pre_relu = frg_compute[i];
frg_compute[i] = relu(frg_compute[i]);
if constexpr (is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::Clamp<ElementCompute>>) {
frg_compute[i] = relu(frg_compute[i], params_compute);
}
else {
frg_compute[i] = relu(frg_compute[i]);
}
frg_aux[i] = frg_compute[i] == pre_relu;
}
@ -435,8 +492,18 @@ struct Sm90ReLUAuxStore {
CUTLASS_DEVICE void
end() {
// Unpack callbacks + params
auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple;
auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple;
auto const& [params_input_aux, params_compute] = params;
auto const& [params_input, params_aux] = params_input_aux;
// Visit the input node
callbacks_input.end();
// Nullptr is no-op
if constexpr (EnableNullptr) {
if (params.ptr_aux == nullptr) {
if (params_aux.ptr_aux == nullptr) {
return;
}
}
@ -473,114 +540,25 @@ struct Sm90ReLUAuxStore {
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
// Unpack params
auto const& [params_input_aux, params_compute] = params;
auto const& [params_input, params_aux] = params_input_aux;
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator<cutlass::uint1b_t>(params.ptr_aux));
Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L)
gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator<cutlass::uint1b_t>(params_aux.ptr_aux));
Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L)
Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
Tensor tC_gAux = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
gAux, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tC_rAux = make_tensor<cutlass::uint1b_t>(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.tCcD), decltype(args.residue_mn)>(
cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
auto callbacks_impl = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.tCcD), decltype(args.residue_mn), decltype(callbacks_impl)>(
cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params, cute::move(callbacks_impl));
}
};
} // namespace detail
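In scalar terms, the fused ReLU-with-aux path above applies the activation and then records one bit per element saying whether the value passed through unchanged, which is exactly the dReLU/dZ mask. A host-side sketch with plain bools instead of the packed uint1b_t predicates (illustrative only, not the CUTLASS visitor code):
#include <algorithm>
// Scalar sketch of "Aux(i) = Z(i) >= 0 ? 1 : 0" computed via the post-activation compare.
inline void relu_with_aux(float const* z, float* out, bool* aux, int n) {
  for (int i = 0; i < n; ++i) {
    float pre = z[i];
    out[i] = std::max(pre, 0.0f);  // ReLU
    aux[i] = (out[i] == pre);      // true exactly when z[i] >= 0
  }
}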
// Specialization on the generic compute+aux EVT
template <
// Compute node
template <class> class Activation,
class ElementOutput,
class ElementCompute,
FloatRoundStyle RoundStyle,
// Aux node
int Stages,
class EpilogueTile,
class StrideMNL,
class SmemLayoutAtom,
class CopyOpR2S,
int Alignment,
bool EnableNullptr,
// Input node
class InputOp
>
struct Sm90TreeVisitor<
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle,
enable_if_t<is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::ReLu<ElementCompute>>, void>>,
Sm90TreeVisitor<
Sm90AuxStore<
Stages,
EpilogueTile,
cutlass::uint1b_t,
RoundStyle,
StrideMNL,
SmemLayoutAtom,
CopyOpR2S,
Alignment,
EnableNullptr
>,
InputOp
>
> : Sm90VisitorImpl<
Sm90VisitorImpl<
InputOp,
detail::Sm90ReLUAuxStore<ElementOutput, ElementCompute, RoundStyle, StrideMNL, Alignment, EnableNullptr>
>,
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
>
{
using Impl =
Sm90VisitorImpl<
Sm90VisitorImpl<
InputOp,
detail::Sm90ReLUAuxStore<ElementOutput, ElementCompute, RoundStyle, StrideMNL, Alignment, EnableNullptr>
>,
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
>;
using Params = typename Impl::Params;
using SharedStorage = typename Impl::SharedStorage;
CUTLASS_HOST_DEVICE
Sm90TreeVisitor() {}
CUTLASS_HOST_DEVICE
Sm90TreeVisitor(
Params const& params,
SharedStorage const& shared_storage)
: Impl(params, shared_storage) {}
template <class CallbacksImpl>
struct ConsumerStoreCallbacks : CallbacksImpl {
CUTLASS_DEVICE
ConsumerStoreCallbacks(CallbacksImpl&& impl)
: CallbacksImpl(cute::forward<CallbacksImpl>(impl)) { }
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<ElementOutput, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
auto& [callbacks_input, callbacks_relu_aux] = get<0>(CallbacksImpl::callbacks_tuple).callbacks_tuple;
Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n);
return callbacks_relu_aux.visit(frg_acc, epi_v, epi_m, epi_n, frg_input);
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};
// Aux load for uint1b_t
template <
View File
@ -85,16 +85,17 @@ using Sm90SplitTreeFetch = Sm90AccFetch;
/////////////////////////////////////////////////////////////////////////////////////////////////
// returns C
template <class Element>
struct Sm90SrcFetch : Sm90VisitorImpl<> {
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return true;
return is_C_load_needed();
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return true;
return not is_void_v<Element>;
}
using Sm90VisitorImpl<>::Sm90VisitorImpl;
@ -105,7 +106,6 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> {
ConsumerStoreCallbacks(SrcTensor const& tCrC)
: tCrC(tCrC) {}
// make this a pointer if we need default ctor for generic tuple of visitors
SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N)
template <typename ElementAccumulator, int FragmentSize>
@ -122,7 +122,7 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> {
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
// register type may differ from logical type so we can't assert matching types here
return ConsumerStoreCallbacks(args.tCrC);
}
};
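Templating Sm90SrcFetch on the source element lets a void C type switch off both the C load and the producer load that feeds it. A toy host-side model of that behavior (plain standard-library traits, not the CUTLASS visitor interface):
#include <type_traits>
// Toy model of the Sm90SrcFetch<Element> change.
template <class ElementSource>
struct src_fetch_model {
  static constexpr bool is_C_load_needed() { return !std::is_void_v<ElementSource>; }
  static constexpr bool is_producer_load_needed() { return is_C_load_needed(); }
};
static_assert(!src_fetch_model<void>::is_C_load_needed(), "void source: skip the C load");
static_assert(src_fetch_model<float>::is_producer_load_needed(), "non-void source: load C");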
@ -344,7 +344,6 @@ struct Sm90AuxLoad {
make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE)
auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE)
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_s2r), decltype(tSR_sAux)>(
cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
}
View File
@ -303,7 +303,7 @@ private:
static constexpr bool IsAtomic = is_atomic<GmemReduceFn<ElementCompute>>::value;
static_assert(IsAtomic, "non-atomic scalar reduction not supported yet");
public:
public:
struct SharedStorage { };
struct Arguments {
@ -814,7 +814,7 @@ public:
}
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n);
@ -843,8 +843,8 @@ public:
return;
}
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
auto [m, n, k, l] = tile_coord_mnkl;
constexpr bool ReferenceSrc = decltype(ref_src)::value;
@ -1002,15 +1002,15 @@ public:
return;
}
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
using ReduceOutput = GmemReduceFn<ElementCompute>;
using ConvertOutput = NumericConverter<ElementOutput, ElementCompute, RoundStyle>;
ReduceOutput reduce_output{};
ConvertOutput convert_output{};
// Reduction over batches
if (size<2>(stride(gCol_l)) == 0) {
CUTLASS_PRAGMA_NO_UNROLL
@ -1051,8 +1051,8 @@ public:
CUTLASS_DEVICE bool
is_reduction_buffer_needed(int epi_m, int epi_n, bool is_last_iteration) const {
auto const& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
auto const& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
return (not IsAtomic && // atomic reduction doesn't use smem
@ -1111,7 +1111,7 @@ public:
auto args_tuple = make_tuple(
bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple), params);
}
View File
@ -563,7 +563,6 @@ struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {
template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -614,7 +613,6 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputT
template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -692,7 +690,6 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
View File
@ -49,7 +49,6 @@ namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
@ -66,13 +65,65 @@ struct ArrayMaximum {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
result[i] = fmax(lhs[i], rhs[i]);
result[i] = platform::max(lhs[i].get(), rhs[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<Element, ElementsPerAccess> operator()(
Array<Element, ElementsPerAccess> const &lhs,
Element rhs) const {
Array<Element, ElementsPerAccess> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
result[i] = platform::max(lhs[i].get(), rhs);
}
return result;
}
};
/// Partial specialization: Element=float
template <int ElementsPerAccess>
struct ArrayMaximum<float, ElementsPerAccess> {
CUTLASS_HOST_DEVICE
Array<float, ElementsPerAccess> operator()(
Array<float, ElementsPerAccess> const &lhs,
Array<float, ElementsPerAccess> const &rhs) const {
Array<float, ElementsPerAccess> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
result[i] = fmax(lhs[i], rhs[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
Array<float, ElementsPerAccess> operator()(
Array<float, ElementsPerAccess> const &lhs,
float rhs) const {
Array<float, ElementsPerAccess> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
result[i] = fmax(lhs[i], rhs);
}
return result;
}
};
/// Partial specialization: Element=half
template <int ElementsPerAccess>
struct ArrayMaximum<half_t, ElementsPerAccess> {
@ -96,6 +147,8 @@ struct ArrayMaximum<half_t, ElementsPerAccess> {
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
}
static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
#else
__half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
__half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data());
@ -133,6 +186,8 @@ struct ArrayMaximum<half_t, ElementsPerAccess> {
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
}
static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
#else
__half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
@ -150,6 +205,90 @@ struct ArrayMaximum<half_t, ElementsPerAccess> {
}
};
/// Partial specialization: Element=bfloat16_t
template <int ElementsPerAccess>
struct ArrayMaximum<bfloat16_t, ElementsPerAccess> {
using NvType = __nv_bfloat16;
using NvTypeV2 = __nv_bfloat162;
CUTLASS_DEVICE
Array<bfloat16_t, ElementsPerAccess> operator()(
Array<bfloat16_t, ElementsPerAccess> const &lhs,
Array<bfloat16_t, ElementsPerAccess> const &rhs) const {
Array<bfloat16_t, ElementsPerAccess> result;
#if __CUDA_ARCH__ >= 800
int const kVectorCount = ElementsPerAccess / 2;
NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
NvTypeV2 const *rhs_ptr = reinterpret_cast<NvTypeV2 const *>(rhs.raw_data());
NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kVectorCount; ++i) {
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
}
#else
NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
NvType const *rhs_ptr = reinterpret_cast<NvType const *>(rhs.raw_data());
NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]);
}
#endif
return result;
}
CUTLASS_DEVICE
Array<bfloat16_t, ElementsPerAccess> operator()(
Array<bfloat16_t, ElementsPerAccess> const &lhs,
bfloat16_t rhs) const {
Array<bfloat16_t, ElementsPerAccess> result;
#if __CUDA_ARCH__ >= 800
int const kVectorCount = ElementsPerAccess / 2;
NvType rhs_raw = reinterpret_cast<NvType const &>(rhs);
NvTypeV2 rhs_pair = __bfloat162bfloat162(rhs_raw);
NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kVectorCount; ++i) {
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
}
static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
#else
NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
NvType const rhs_raw = reinterpret_cast<NvType const &>(rhs);
NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]);
}
#endif
return result;
}
};
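The new element-vs-scalar overloads behave like the reference loop below. This host-side sketch uses std::array and std::max in place of cutlass::Array and the vectorized __hmax2 path, so it only illustrates the intended semantics:
#include <algorithm>
#include <array>
// Reference semantics for the array-vs-scalar maximum overloads added above.
template <typename Element, int ElementsPerAccess>
std::array<Element, ElementsPerAccess>
array_maximum_reference(std::array<Element, ElementsPerAccess> const& lhs, Element rhs) {
  std::array<Element, ElementsPerAccess> result{};
  for (int i = 0; i < ElementsPerAccess; ++i) {
    result[i] = std::max(lhs[i], rhs);  // per-element max against the broadcast scalar
  }
  return result;
}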
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element, int ElementsPerAccess>
@ -187,6 +326,25 @@ struct ReluConditional<half_t, ElementsPerAccess> {
}
};
template <int ElementsPerAccess>
struct ReluConditional<bfloat16_t, ElementsPerAccess> {
CUTLASS_DEVICE
void operator()(
bool conditional[],
Array<bfloat16_t, ElementsPerAccess> const &fragment,
bfloat16_t threshold) const {
__nv_bfloat16 y = reinterpret_cast<__nv_bfloat16 const &>(threshold);
__nv_bfloat16 const *x = reinterpret_cast<__nv_bfloat16 const *>(fragment.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
conditional[i] = !__hlt(x[i], y);
}
}
};
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
View File
@ -92,9 +92,9 @@ public:
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute threshold; ///< minimum value that is output
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
ElementCompute threshold; ///< minimum value that is output
//
// Methods
//
View File
@ -87,9 +87,9 @@ public:
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute threshold; ///< minimum value that is output
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
ElementCompute threshold; ///< minimum value that is output
//
// Methods
//
View File
@ -1698,13 +1698,13 @@ private:
//
if (OutputOp::kStoreZ) {
destination_iterator += reduce_fragment_idx;
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator += reduce_fragment_idx;
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
};
View File
@ -187,7 +187,7 @@ CUTLASS_CONSTEXPR_IF_CXX17
value_t lcm(value_t a, value_t b) {
value_t temp = gcd(a, b);
return temp ? (cutlass::abs_for_integer(a) / temp * cutlass::abs_for_integer(b)) : 0;
return temp ? (cutlass::abs_for_integer(a) / temp * cutlass::abs_for_integer(b)) : value_t();
}
/**
@ -207,8 +207,11 @@ template <typename value_t>
CUTLASS_HOST_DEVICE
constexpr
value_t lcm_cxx11(value_t a, value_t b) {
return gcd_cxx11(a, b) ? (cutlass::abs_for_integer(a) / gcd_cxx11(a, b) * cutlass::abs_for_integer(b)) : 0;
return gcd_cxx11(a, b) ? (cutlass::abs_for_integer(a) / gcd_cxx11(a, b) *
cutlass::abs_for_integer(b))
: value_t();
}
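Returning a value-initialized value_t() instead of the literal 0 keeps lcm well-formed for value types that are not implicitly constructible from an integer. A standalone C++14 sketch of the same gcd/lcm shape (hypothetical my_* helpers, not the CUTLASS functions):
// Illustrative stand-ins for the gcd/lcm pair above; integral value_t assumed.
template <typename value_t>
constexpr value_t my_abs(value_t a) { return a < value_t() ? static_cast<value_t>(-a) : a; }
template <typename value_t>
constexpr value_t my_gcd(value_t a, value_t b) {
  return b ? my_gcd(b, static_cast<value_t>(a % b)) : a;
}
template <typename value_t>
constexpr value_t my_lcm(value_t a, value_t b) {
  value_t g = my_gcd(a, b);
  // value_t() (value-initialization) instead of 0: no conversion from int required.
  return g ? static_cast<value_t>(my_abs(a) / g * my_abs(b)) : value_t();
}
static_assert(my_lcm(4, 6) == 12, "lcm(4, 6) == 12");
static_assert(my_lcm(0, 6) == 0,  "an operand of zero yields zero");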
/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
CUTLASS_CONSTEXPR_IF_CXX17
View File
@ -31,6 +31,7 @@
#pragma once
#include "cutlass/detail/layout.hpp"
#include "cutlass/detail/collective.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
@ -322,13 +323,21 @@ is_input_fp8() {
(cute::is_same_v<ElementB, float_e4m3_t> || cute::is_same_v<ElementB, float_e5m2_t>));
}
template <class ElementA, class LayoutA, class ElementB, class LayoutB>
// We need to handle the tuples in this function since it is used in SFINAE dispatch in the CollectiveBuilder.
// At that point, it is not guaranteed that the tuples have been split out into the required parts.
template <class MaybeTupleElementA, class LayoutA, class MaybeTupleElementB, class LayoutB>
constexpr bool
is_use_rmem_A() {
using ElementA = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementA>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementB>;
constexpr bool IsABDifferentWidth = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
constexpr bool HasScales = cute::is_tuple<MaybeTupleElementA>::value ^ cute::is_tuple<MaybeTupleElementB>::value;
constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes<ElementA, ElementB>();
constexpr bool IsLayoutAkBk = cutlass::gemm::detail::is_k_major_A<LayoutA>() &&
cutlass::gemm::detail::is_k_major_B<LayoutB>();
constexpr bool IsUseRmemA = !IsInputSizeTwoBytes && !IsLayoutAkBk;
constexpr bool IsUseRmemA = (!IsInputSizeTwoBytes && !IsLayoutAkBk) || IsABDifferentWidth || HasScales;
return IsUseRmemA;
}
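The updated predicate routes A through registers not only for the previous non-16-bit, non-K-major cases but also whenever the operands differ in width or exactly one of them carries a scale tuple. A simplified stand-alone model of that decision, with plain booleans standing in for the CUTLASS type traits:
// Illustrative only: mirrors the boolean structure of is_use_rmem_A().
constexpr bool use_rmem_A(bool inputs_are_two_bytes,
                          bool both_k_major,
                          bool ab_different_width,
                          bool has_scales) {
  return (!inputs_are_two_bytes && !both_k_major) || ab_different_width || has_scales;
}
// e.g. an int8 x fp16 mixed-width GEMM stages A via RF even when both operands are K-major:
static_assert(use_rmem_A(/*two bytes*/ false, /*K-major*/ true,
                         /*different width*/ true, /*scales*/ false), "");
// ...while a plain fp16 x fp16 TN GEMM does not:
static_assert(!use_rmem_A(true, true, false, false), "");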
View File
@ -79,6 +79,50 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes> stage_cou
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int CapacityBytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCount<stages> stage_count) {
return stages;
}
template <class Element>
constexpr int get_bits_for_possibly_void_element() {
if constexpr (cute::is_same_v<Element, void>) {
return 0;
}
else {
return sizeof_bits<Element>::value;
}
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int CapacityBytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int carveout_bytes>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout<carveout_bytes> stage_count) {
// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int scale_zero_k_tile = 1;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
constexpr int s_bits = get_bits_for_possibly_void_element<ElementScale>();
constexpr int z_bits = get_bits_for_possibly_void_element<ElementZero>();
constexpr int scale_bytes = (s_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile) / 8;
constexpr int zero_bytes = (z_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile) / 8;
static_assert(scale_bytes % 128 == 0, "Scale bytes must be a multiple of 128");
static_assert(zero_bytes % 128 == 0, "Zero bytes must be a multiple of 128");
// When scales are void, s_bits will be 0 so no smem will be allocated for scales.
constexpr int stage_bytes =
(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
scale_bytes + zero_bytes + stage_barrier_bytes;
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
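For a rough sense of the arithmetic above, the illustration below plugs hypothetical numbers into the same formula (128x128x64 tile, 8-bit A, 16-bit B, 16-bit scales, no zero points, and an assumed shared-memory budget; none of these constants are taken from the library):
// Illustrative only: repeats the stage-byte arithmetic with made-up parameters.
constexpr int tile_m = 128, tile_n = 128, tile_k = 64;
constexpr int a_bits = 8, b_bits = 16, s_bits = 16, z_bits = 0;  // void zero points contribute 0 bits
constexpr int scale_zero_k_tile = 1;
constexpr int stage_barrier_bytes = 32;
constexpr int stage_bytes =
    (a_bits * tile_m * tile_k) / 8 +                 // A tile:   8192 B
    (b_bits * tile_n * tile_k) / 8 +                 // B tile:  16384 B
    (s_bits * tile_m * scale_zero_k_tile) / 8 +      // scales:    256 B
    (z_bits * tile_m * scale_zero_k_tile) / 8 +      // zeros:       0 B
    stage_barrier_bytes;                             // total:   24864 B
constexpr int assumed_capacity_bytes = 224 * 1024;   // assumption, not sm90_smem_capacity_bytes
constexpr int stages = assumed_capacity_bytes / stage_bytes;  // 9 stages under these assumptions
static_assert(stages >= 2, "the mainloop needs at least a two-stage pipeline");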
template <class ElementA, class LayoutA, class ElementB, class LayoutB>
constexpr bool
is_swapAB(){
@ -105,29 +149,6 @@ is_warpspecialized_transpose_B(){
return IsWarpSpecializedTransposeB;
}
template <typename ElementA, typename ElementB>
struct Sm90TypeWidths {
static constexpr bool IsElementALarger = (cute::sizeof_bits_v<ElementA>) > cute::sizeof_bits_v<ElementB>;
using WideType = cute::conditional_t<IsElementALarger, ElementA, ElementB>;
using NarrowType = cute::conditional_t<IsElementALarger, ElementB, ElementA>;
};
template <class ElementA, class LayoutA, class ElementB, class LayoutB>
constexpr bool
sm90_is_narrow_type_k_major() {
using Widths = Sm90TypeWidths<ElementA, ElementB>;
using NarrowType = typename Widths::NarrowType;
using WideType = typename Widths::WideType;
constexpr bool IsANarrow = cute::is_same_v<NarrowType, ElementA>;
constexpr cute::GMMA::Major NarrowGmmaMajor = IsANarrow ? detail::gmma_rs_tag_to_major_A<LayoutA>() :
detail::gmma_rs_tag_to_major_B<LayoutB>();
constexpr bool IsNarrowLayoutKMajor = NarrowGmmaMajor == cute::GMMA::Major::K;
return IsNarrowLayoutKMajor;
}
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -163,18 +184,24 @@ struct CollectiveBuilder<
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>) &&
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
// For fp32 types, map to tf32 MMA value type
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@ -183,7 +210,8 @@ static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> || IsArrayOfPointersGemm,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
@ -199,10 +227,12 @@ static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
MmaElementA, MmaElementB, TileShape_MNK>(StageCountType{});
/* For FP8 use a separate mainloop compared to other datatypes */
using DispatchPolicy = cute::conditional_t<IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
@ -267,7 +297,7 @@ struct CollectiveBuilder<
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutB>();
@ -323,13 +353,13 @@ struct CollectiveBuilder<
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_RS Mixed GEMM
// GMMA_TMA_WS_RS Mixed Scaled GEMM
template <
class ElementPairA_,
class GmemLayoutPairA_,
class GmemLayoutA_,
int AlignmentA,
class ElementPairB_,
class GmemLayoutPairB_,
class GmemLayoutB_,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
@ -341,10 +371,10 @@ struct CollectiveBuilder<
arch::Sm90,
arch::OpClassTensorOp,
ElementPairA_,
GmemLayoutPairA_,
GmemLayoutA_,
AlignmentA,
ElementPairB_,
GmemLayoutPairB_,
GmemLayoutB_,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
@ -357,22 +387,38 @@ struct CollectiveBuilder<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeMixedInput>)>
> {
private:
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementPairA_>;
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementPairB_>;
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementPairA_>;
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementPairB_>;
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementPairA_>::value && !cute::is_tuple<ElementPairB_>::value;
public:
static constexpr bool IsATransformed = cute::sizeof_bits_v<ElementPairA_> < cute::sizeof_bits_v<ElementPairB_>;
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>;
static_assert(cute::is_tuple<ElementPairA_>::value ^ cute::is_tuple<ElementPairB_>::value ||
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value)),
"Either A OR B must be a tuple or the widths of A and B must be different.");
// Split out items for processing, no splitting for now since scales aren't supported.
using ElementA = ElementPairA_;
using ElementB = ElementPairB_;
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
using GmemLayoutA = GmemLayoutPairA_;
using GmemLayoutB = GmemLayoutPairB_;
using GmemLayoutA = GmemLayoutA_;
using GmemLayoutB = GmemLayoutB_;
using ElementPairA = cute::conditional_t<IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementPairA_>;
using ElementPairB = cute::conditional_t<!IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementPairB_>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutB>();
@ -382,12 +428,7 @@ public:
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands.
static constexpr bool SwapAB = !IsATransformed;
static_assert(detail::sm90_is_narrow_type_k_major<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(), "The narrow type must be K-major.");
static_assert((IsATransformed && (cute::sizeof_bits_v<ElementA> <= 8) && (sizeof(ElementB) == 2)) ||
(!IsATransformed && (cute::sizeof_bits_v<ElementB> <= 8) && (sizeof(ElementA) == 2)) ||
(GmmaMajorA == cute::GMMA::Major::K && GmmaMajorB == cute::GMMA::Major::K),
"The unscaled element must be 2 bytes OR both inputs must be K-major");
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
@ -400,6 +441,7 @@ public:
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>());
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<GmmaMajorB, ElementB,
@ -407,8 +449,8 @@ public:
using RealElementA = cute::conditional_t<SwapAB, ElementB, ElementA>;
using RealElementB = cute::conditional_t<SwapAB, ElementA, ElementB>;
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
RealElementA, RealElementB, TileShape_MNK>(StageCountType{});
static constexpr int PipelineStages = detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK>(StageCountType{});
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::DefaultCopy, ElementA>>;
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::DefaultCopy, ElementB>, void>;
@ -416,35 +458,24 @@ public:
using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
using StrideAPair = TagToStrideA_t<GmemLayoutA>;
using StrideBPair = TagToStrideB_t<GmemLayoutB>;
using StrideA = TagToStrideA_t<GmemLayoutA>;
using StrideB = TagToStrideB_t<GmemLayoutB>;
using GmemTiledCopyAPair = GmemTiledCopyA;
using SmemLayoutAtomAPair = SmemLayoutAtomA;
using SmemCopyAtomAPair = SmemCopyAtomA;
using GmemTiledCopyBPair = GmemTiledCopyB;
using SmemLayoutAtomBPair = SmemLayoutAtomB;
using SmemCopyAtomBPair = SmemCopyAtomB;
// If the src type of the converter is the same as ElementA,
// interpret this as if the user wanted to apply the scale to the A matrix.
using CollectiveOp = CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementPairA_,
StrideAPair,
ElementPairB_,
StrideBPair,
ElementPairA,
StrideA,
ElementPairB,
StrideB,
TiledMma,
GmemTiledCopyAPair,
SmemLayoutAtomAPair,
SmemCopyAtomAPair,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyBPair,
SmemLayoutAtomBPair,
SmemCopyAtomBPair,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity
>;
@ -483,7 +514,9 @@ struct CollectiveBuilder<
cute::enable_if_t<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>>
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
@ -495,13 +528,16 @@ struct CollectiveBuilder<
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>,
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
IsArrayOfPointersGemm,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
@ -517,8 +553,9 @@ struct CollectiveBuilder<
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized<
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
@ -580,7 +617,7 @@ struct CollectiveBuilder<
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
// For fp32 types, map to tf32 MMA value type
@ -721,7 +758,7 @@ struct CollectiveBuilder<
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
// For fp32 types, map to tf32 MMA value type
@ -819,7 +856,7 @@ struct CollectiveBuilder<
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
// For fp32 types, map to tf32 MMA value type
@ -918,13 +955,20 @@ struct CollectiveBuilder<
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr bool IsTmaCompatible = detail::is_aligned<
ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>();
using ExtractedElementA = detail::deduce_mixed_width_dtype_t<0, ElementA>;
using ExtractedElementB = detail::deduce_mixed_width_dtype_t<0, ElementB>;
static constexpr bool IsMixedWidthInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
static constexpr bool IsTmaCompatible = detail::is_aligned<
ExtractedElementA, AlignmentA, ExtractedElementB, AlignmentB, detail::tma_alignment_bytes>();
// Users opt into scales via the builder by passing a tuple of Elements for the input that will be scaled. We detect
// scale support if ONLY one of the inputs has a tuple describing it.
static constexpr bool OnlyOneIsTuple = cute::is_tuple<ElementA>::value ^ cute::is_tuple<ElementB>::value;
static constexpr bool IsDifferentWidth = sizeof_bits<ExtractedElementA>::value != sizeof_bits<ExtractedElementB>::value;
static constexpr bool IsMixedWidthInput = IsDifferentWidth || (IsDifferentWidth && OnlyOneIsTuple);
#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1)))
// Persistent schedules perform best for CUDA Toolkits with version >= 12.1
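As the comment above describes, scale support is opted into by wrapping exactly one operand's element type in a tuple. The toy model below reproduces only the XOR detection rule with std::tuple stand-ins; it is not the CUTLASS builder interface:
#include <tuple>
#include <type_traits>
// Toy model: exactly one operand wrapped in a tuple means that operand carries scale
// (and optionally zero-point) element types.
template <class T> struct is_a_tuple : std::false_type {};
template <class... Ts> struct is_a_tuple<std::tuple<Ts...>> : std::true_type {};
template <class ElementPairA, class ElementPairB>
constexpr bool only_one_is_tuple_v = is_a_tuple<ElementPairA>::value ^ is_a_tuple<ElementPairB>::value;
struct int4_stand_in {};  // placeholder for a narrow weight type such as cutlass::int4b_t
// A narrow weight operand paired with scale types opts in against a plain activation operand...
static_assert(only_one_is_tuple_v<std::tuple<int4_stand_in, float>, float>, "");
// ...whereas two plain operands do not.
static_assert(!only_one_is_tuple_v<float, float>, "");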
View File
@ -56,7 +56,7 @@ template <
class TransformB
>
struct CollectiveMma {
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Could not find a mainloop specialization.");
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -73,5 +73,6 @@ struct CollectiveMma {
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
View File
@ -0,0 +1,741 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <
int Stages,
class ClusterShape,
class KernelSchedule,
class TileShape_,
class ElementA_,
class StrideA_,
class ElementB_,
class StrideB_,
class TiledMma_,
class GmemTiledCopyA_,
class SmemLayoutAtomA_,
class SmemCopyAtomA_,
class TransformA_,
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>,
TileShape_,
ElementA_,
StrideA_,
ElementB_,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
{
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
} tensors;
struct TensorMapStorage : cute::aligned_struct<128> {
cute::TmaDescriptor smem_tensormap_A;
cute::TmaDescriptor smem_tensormap_B;
} tensormaps;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>;
using StridesA = cute::conditional_t<IsGroupedGemmKernel, StrideA const*, StrideA>;
using StridesB = cute::conditional_t<IsGroupedGemmKernel, StrideB const*, StrideB>;
// Host side kernel arguments
struct Arguments {
ElementA const** ptr_A;
StridesA dA;
ElementB const** ptr_B;
StridesB dB;
};
// Device side kernel params
struct Params {
TMA_A tma_load_a;
TMA_B tma_load_b;
void* tensormaps;
InternalElementA const** ptr_A;
StridesA dA;
InternalElementB const** ptr_B;
StridesB dB;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(
ProblemShape problem_shapes,
Arguments const& args,
void* workspace) {
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(0), 1);
auto [M,N,K,L] = problem_shape_MNKL;
const uint32_t mock_L = 1;
// These tensor pointers are only used to create tensormap/tma desc.
// This address will be replaced with the correct address before the initial TMA load.
InternalElementA const* ptr_A_first_batch = reinterpret_cast<InternalElementA const*>(args.ptr_A);
InternalElementB const* ptr_B_first_batch = reinterpret_cast<InternalElementB const*>(args.ptr_B);
cudaError_t cuda_error = cudaGetLastError(); // clear previous error
StrideA stride_a;
StrideB stride_b;
if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless
stride_a = StrideA{};
stride_b = StrideB{};
}
else {
stride_a = args.dA;
stride_b = args.dB;
}
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), stride_a));
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), stride_b));
TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{},
tensor_a,
SmemLayoutA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
TMA_B tma_load_b = make_tma_copy(
GmemTiledCopyB{},
tensor_b,
SmemLayoutB{}(_,_,cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
void* tensormaps = workspace;
return {
tma_load_a,
tma_load_b,
tensormaps,
reinterpret_cast<InternalElementA const**>(args.ptr_A),
args.dA,
reinterpret_cast<InternalElementB const**>(args.ptr_B),
args.dB
};
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
constexpr uint32_t NumInputTensors = 2;
constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
// Allocate gmem space for input tensormaps, one copy per SM: A tensormap copies followed by B tensormap copies
return (NumInputTensors * SizeOfCuTensorMap * sm_count);
}
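Each SM gets its own copy of the A and B tensormaps in global memory so they can be patched per batch or group. Assuming the usual 128-byte TMA descriptor and, say, 132 SMs (both assumptions, not values read from the code above), the workspace stays small:
#include <cstddef>
// Illustrative sizing only.
constexpr std::size_t size_of_tensormap = 128;  // a CUtensorMap/TmaDescriptor is 128 B
constexpr int         num_input_tensors = 2;    // one descriptor for A, one for B
constexpr int         sm_count          = 132;  // e.g. a full H100 SXM part
constexpr std::size_t workspace_bytes   = num_input_tensors * size_of_tensormap * sm_count;
static_assert(workspace_bytes == 33792, "roughly 33 KiB of gmem workspace");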
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) {
return cutlass::Status::kSuccess;
}
template<class ProblemShape>
CUTLASS_HOST_DEVICE static bool
can_implement(
ProblemShape problem_shapes,
Arguments const& args) {
constexpr int tma_alignment_bits = 128;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
bool implementable = true;
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
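// Bytes of A and B copied into shared memory per pipeline stage:
// (BLK_M * BLK_K * bits(ElementA) + BLK_N * BLK_K * bits(ElementB)) / 8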
static constexpr uint32_t TmaTransactionBytes =
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8 +
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8;
// Set up the data needed by this collective for load and mma.
// Returns a tuple of tensors. The collective and the kernel layer have the contract that the
// returned tuple must contain at least two elements, with the first two elements being:
// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
const int32_t mock_L = 1;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl);
}
// Perform a collective-scoped matrix multiply-accumulate
// Producer Perspective
template <
class TensorA, class TensorB,
class TensorMapA, class TensorMapB,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load(
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB> const& load_inputs,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
if (warp_idx_in_warp_group == 0 and lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (warp_idx_in_warp_group == 0 and lane_predicate) {
// This helps avoid early exit of blocks in Cluster.
// Waits for all stages to either be released (all
// Consumer UNLOCKs), or if the stage was never used
// then it would just be acquired since the phase was
// still inverted from make_producer_start_state.
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <
class FrgTensorC
>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline,
PipelineState smem_pipe_read,
FrgTensorC& accum,
int k_tile_count,
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
"ERROR : Incorrect number of MMAs in flight");
// We release buffers to the producer warps (DMA load) with some MMAs still in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum);
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
++smem_pipe_read;
}
warpgroup_fence_operand(accum);
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum);
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accum);
// UNLOCK smem_pipe_release, done _computing_ on it
pipeline.consumer_release(smem_pipe_release);
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
warpgroup_fence_operand(accum);
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
//
// Methods to perform different parts of TMA/Tensormap modifications
//
CUTLASS_DEVICE auto
tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const {
cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);
cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];
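// The workspace holds sm_count descriptors for A followed by sm_count descriptors for B,
// so each SM owns exactly one A slot and one B slot and never races with other SMs.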
if (cute::elect_one_sync()) {
// Bringing tensormaps from params to gmem for modification later
Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{});
Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{});
copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(gA_tensormap));
copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(gB_tensormap));
}
return cute::make_tuple(tma_desc_a, tma_desc_b);
}
// Bring the tensormaps into smem (to be done by a single thread)
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_fetch_to_smem(
TensorMapStorage& shared_tensormap,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) const {
Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{});
Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{});
Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{});
Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{});
copy(recast<uint128_t>(gA_tensormap), recast<uint128_t>(sA_tensormap));
copy(recast<uint128_t>(gB_tensormap), recast<uint128_t>(sB_tensormap));
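// Commit the asynchronous copies and wait for the descriptors to land in shared memory
// before they are modified.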
cp_async_fence();
cp_async_wait<0>();
}
// Replace the global tensor address (to be done by a single thread)
CUTLASS_DEVICE
void
tensormaps_replace_global_address(
TensorMapStorage& shared_tensormap,
Params const& mainloop_params,
int32_t next_batch) {
// Replacing global_address for the next batch
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A,
mainloop_params.ptr_A[next_batch]);
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B,
mainloop_params.ptr_B[next_batch]);
}
// Replace dims and strides for the global tensor - used only for Grouped GEMM (to be done by a single thread)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_replace_global_tensor_properties(
TensorMapStorage& shared_tensormap,
Params const& mainloop_params,
int32_t next_group,
ProblemShape_MNKL problem_shape_mnkl) {
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
const uint32_t K = get<2>(problem_shape_mnkl);
// Only consider dimensions and strides that we need to recalculate and replace for each group
constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
static_assert(TensorRank == 3,
"Descriptor modification for global dims & strides expects rank 3.");
cute::array<uint32_t, TensorRank> prob_shape_A = {1,1,1};
cute::array<uint64_t, TensorRank> prob_stride_A = {0,0,0};
cute::array<uint32_t, TensorRank> prob_shape_B = {1,1,1};
cute::array<uint64_t, TensorRank> prob_stride_B = {0,0,0};
InternalElementA const* ptr_A = nullptr;
Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);
InternalElementB const* ptr_B = nullptr;
Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a,
prob_shape_A, prob_stride_A);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
prob_shape_B, prob_stride_B);
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
prob_shape_A,
prob_stride_A);
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_B,
prob_shape_B,
prob_stride_B);
}
template <class TensorMapA, class TensorMapB, class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_perform_update(
TensorMapStorage& shared_tensormap,
Params const& mainloop_params,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
ProblemShape_MNKL problem_shape_mnkl,
int32_t next_batch) {
if (cute::elect_one_sync()) {
// Bringing tensormaps to smem
tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps);
// Replacing global_address for the next batch
tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch);
if constexpr (IsGroupedGemmKernel) {
// Replacing global dims and strides for the next batch
tensormaps_replace_global_tensor_properties(shared_tensormap,
mainloop_params, next_batch, problem_shape_mnkl);
}
}
}
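// After this update, the kernel layer is expected to write the modified descriptors back to
// global memory with tensormaps_cp_fence_release() and to call tensormaps_fence_acquire()
// before the updated tensormaps are consumed by TMA loads.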
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_cp_fence_release (
TensorMapStorage& shared_tensormap,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
// The entire warp must execute this (i.e., it is a warp-aligned operation)
tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A);
tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B);
}
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_fence_acquire(cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -134,9 +134,7 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using MainloopPipeline = cutlass::PipelineTmaAsync<
DispatchPolicy::Stages,
typename DispatchPolicy::ClusterShape>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
@ -336,22 +334,20 @@ struct CollectiveMma<
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
/// that the tuple must contain at least two elements, with the first two elements being:
/// Returned tuple must contain at least two elements, with the first two elements being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL,
class TileShapeMNK>
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) {
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
@ -362,8 +358,8 @@ struct CollectiveMma<
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl);
}
@ -379,7 +375,7 @@ struct CollectiveMma<
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB> const& tiled_tensors,
cute::tuple<TensorA, TensorB> const& load_inputs,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
@ -405,16 +401,16 @@ struct CollectiveMma<
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors);
Tensor gB_nkl = get<1>(tiled_tensors);
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)

View File

@ -106,9 +106,7 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using MainloopPipeline = cutlass::PipelineTmaAsync<
DispatchPolicy::Stages,
typename DispatchPolicy::ClusterShape>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
@ -147,8 +145,7 @@ struct CollectiveMma<
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
struct SharedStorage
{
struct SharedStorage {
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
@ -244,13 +241,13 @@ struct CollectiveMma<
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <
class TensorA, class TMA_LOAD_A,
class TensorB, class TMA_LOAD_B,
@ -281,8 +278,8 @@ struct CollectiveMma<
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for A and B
@ -295,11 +292,11 @@ struct CollectiveMma<
auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
//
// Prepare TMA membars and PREFETCH
@ -335,9 +332,7 @@ struct CollectiveMma<
params.is_leader = warp_group_thread_idx == 0;
params.num_consumers = NumThreadsPerWarpGroup;
MainloopPipeline pipeline(
storage.pipeline_storage,
params);
MainloopPipeline pipeline(storage.pipeline_storage, params, ClusterShape{});
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA

View File

@ -107,9 +107,7 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using MainloopPipeline = cutlass::PipelineTmaAsync<
DispatchPolicy::Stages,
typename DispatchPolicy::ClusterShape>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
@ -255,22 +253,20 @@ struct CollectiveMma<
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
/// that the tuple must contain at least two elements, with the first two elements being:
/// Returned tuple must contain at least two elements, with the first two elements being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL,
class TileShapeMNK>
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) {
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
@ -281,8 +277,8 @@ struct CollectiveMma<
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl);
}
@ -298,14 +294,12 @@ struct CollectiveMma<
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB> const& tiled_tensors,
cute::tuple<TensorA, TensorB> const& load_inputs,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
TensorStorage& shared_tensors) {
using namespace cute;
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
@ -322,16 +316,16 @@ struct CollectiveMma<
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors);
Tensor gB_nkl = get<1>(tiled_tensors);
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
@ -386,10 +380,7 @@ struct CollectiveMma<
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(
MainloopPipeline pipeline,
PipelineState smem_pipe_write)
{
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
@ -418,8 +409,7 @@ struct CollectiveMma<
int k_tile_count,
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params)
{
Params const& mainloop_params) {
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

View File

@ -108,9 +108,7 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using MainloopPipeline = cutlass::PipelineTmaAsync<
DispatchPolicy::Stages,
typename DispatchPolicy::ClusterShape>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
@ -261,14 +259,12 @@ struct CollectiveMma<
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
/// that the tuple must contain at least two elements, with the first two elements being:
/// Returned tuple must contain at least two elements, with the first two elements being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL,
class TileShapeMNK>
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) {
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
@ -279,8 +275,8 @@ struct CollectiveMma<
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl);
}
@ -296,7 +292,7 @@ struct CollectiveMma<
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB> const& tiled_tensors,
cute::tuple<TensorA, TensorB> const& load_inputs,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
@ -320,8 +316,8 @@ struct CollectiveMma<
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors);
Tensor gB_nkl = get<1>(tiled_tensors);
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

View File

@ -42,6 +42,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/detail/layout.hpp"
#include "cutlass/detail/mma.hpp"
#include "cutlass/cuda_host_adapter.hpp"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
@ -106,9 +107,12 @@ public:
using LayoutC = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideC>;
using LayoutD = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideD>;
// NOTE: 3.0 kernels do not support complex transforms for now ...
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
static ComplexTransform const kTransformA = cute::is_same_v<typename GemmKernel::CollectiveMainloop::TransformA, cute::conjugate> ?
ComplexTransform::kConjugate : ComplexTransform::kNone;
static ComplexTransform const kTransformB = cute::is_same_v<typename GemmKernel::CollectiveMainloop::TransformB, cute::conjugate> ?
ComplexTransform::kConjugate : ComplexTransform::kNone;
// Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0
using MathOperator = cutlass::arch::OpMultiplyAdd;
@ -141,17 +145,17 @@ public:
static int const kThreadCount = GemmKernel::MaxThreadsPerBlock;
// Warp shape is not a primary API type in 3.x
// But we can best approximate it by inspecting the TiledMma::TiledShape_MNK
// But we can best approximate it by inspecting the TiledMma
// For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K
// We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads
static constexpr int WarpsInMma = cute::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32);
static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32);
static constexpr int WarpsInMmaM = 4;
static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>;
using WarpShape = cutlass::gemm::GemmShape<
cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM,
cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN,
cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>;
CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM,
CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN,
CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>;
static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages;
@ -270,7 +274,12 @@ public:
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
initialize(
Arguments const& args,
void* workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
@ -283,20 +292,33 @@ public:
// Initialize the Params structure
params_ = GemmKernel::to_underlying_arguments(args, workspace);
// account for dynamic smem capacity if needed
int smem_size = GemmKernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
// Don't set the function attributes - require the CudaHostAdapter to set it.
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
return Status::kSuccess;
}
else {
//
// Account for dynamic smem capacity if needed
//
int smem_size = GemmKernel::SharedStorageSize;
CUTLASS_ASSERT(cuda_adapter == nullptr);
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
}
return Status::kSuccess;
}
@ -317,7 +339,7 @@ public:
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversal::run()");
dim3 const block = GemmKernel::get_block_shape();
dim3 const grid = get_grid_shape(params);
@ -331,13 +353,54 @@ public:
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
void const* kernel = (void const*) device_kernel<GemmKernel>;
void* kernel_params[] = {&params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
if constexpr (kEnableCudaHostAdapter) {
//
// Use the cuda host adapter
//
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
launch_result = cuda_adapter->launch(
grid, cluster, block, smem_size, stream, kernel_params, 0
);
}
else {
return Status::kErrorInternal;
}
}
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
void const* kernel = (void const*) device_kernel<GemmKernel>;
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
}
else {
launch_result = Status::kSuccess;
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
void* kernel_params[] = {&params};
launch_result = cuda_adapter->launch(
grid, block, smem_size, stream, kernel_params, 0
);
}
else {
return Status::kErrorInternal;
}
}
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
}
}
cudaError_t result = cudaGetLastError();
@ -356,18 +419,27 @@ public:
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
run(
Arguments const& args,
void* workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr
) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
status = run(params_, stream, cuda_adapter);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
operator()(
Arguments const& args,
void* workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
return run(args, workspace, stream, cuda_adapter);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
@ -387,7 +459,7 @@ public:
////////////////////////////// CUTLASS 2.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template <typename GemmKernel_>
template <class GemmKernel_>
class GemmUniversalAdapter<
GemmKernel_,
cute::enable_if_t<not gemm::detail::IsCutlass3GemmKernel<GemmKernel_>::value>>
@ -501,9 +573,14 @@ public:
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
Status initialize(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr
) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter);
}
/// Lightweight update given a subset of arguments.
@ -513,13 +590,18 @@ public:
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
Status run(
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
return underlying_operator_.run(stream);
return underlying_operator_.run(stream, cuda_adapter);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
Status operator()(
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
return run(stream);
}
@ -527,12 +609,13 @@ public:
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
Status status = initialize(args, workspace, stream);
Status status = initialize(args, workspace, stream, cuda_adapter);
if (status == Status::kSuccess) {
status = run(stream);
status = run(stream, cuda_adapter);
}
return status;

View File

@ -46,6 +46,7 @@
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/cuda_host_adapter.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
@ -69,6 +70,8 @@ class GemmUniversalBase {
public:
using GemmKernel = GemmKernel_;
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
@ -295,7 +298,8 @@ public:
Status initialize(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr)
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
@ -323,9 +327,8 @@ public:
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
@ -339,13 +342,27 @@ public:
"block: (" << block << "), "
"SMEM: (" << smem_size_ << ")");
Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
void* kernel_params[] = {&params_};
return cuda_adapter->launch(grid, block, smem_size_, stream, kernel_params, 0);
}
else {
return Status::kErrorInternal;
}
}
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
// Query for errors
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
// Query for errors
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
@ -363,12 +380,13 @@ public:
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr)
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
status = run(stream, cuda_adapter);
}
return status;

View File

@ -53,6 +53,8 @@ struct KernelTma { };
struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperative { };
//////////////////////////////////////////////////////////////////////////////
@ -65,6 +67,8 @@ struct KernelTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum : KernelGroupTmaWarpSpecializedCooperative { };
// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -225,6 +229,23 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
"KernelSchedule must be one of the warp specialized policies");
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelGroupTmaWarpSpecializedCooperative
>
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
static_assert(
cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>,
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
};
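// Illustrative instantiation (hypothetical): a 4-stage pipeline with a 2x1x1 cluster using the
// grouped-GEMM schedule:
//   using Mainloop = MainloopSm90ArrayTmaGmmaWarpSpecialized<
//       4, Shape<_2,_1,_1>, KernelGroupTmaWarpSpecializedCooperative>;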
//////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm

View File

@ -69,6 +69,7 @@ enum class GemmUniversalMode {
kGemmSplitKParallel,
kBatched,
kArray,
kGrouped,
kInvalid
};

View File

@ -0,0 +1,111 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains definitions and utility functions for describing problem shapes
for 3.x Ptr-Array GEMMs and Grouped GEMMs.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/tensor_coord.h"
#include "cute/container/array.hpp"
#if ! defined(__CUDACC_RTC__)
#include <initializer_list>
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_>
struct GroupProblemShape {
using UnderlyingProblemShape = ProblemShape_;
int32_t num_groups = 1;
UnderlyingProblemShape* problem_shapes = nullptr;
UnderlyingProblemShape const* host_problem_shapes = nullptr;
CUTLASS_HOST_DEVICE
int32_t groups() const { return num_groups; }
CUTLASS_HOST_DEVICE
UnderlyingProblemShape const
get_problem_shape(int32_t group_idx) const {
return problem_shapes[group_idx];
}
CUTLASS_HOST_DEVICE
UnderlyingProblemShape const
get_host_problem_shape(int32_t group_idx) const {
return host_problem_shapes[group_idx];
}
};
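// Hypothetical usage sketch (illustrative names): a rank-3 MNK shape per group, with problem
// sizes resident in both device and host memory:
//   using ProblemShape = GroupProblemShape<cute::Shape<int,int,int>>;
//   ProblemShape problem_shapes{num_groups, device_problem_sizes_ptr, host_problem_sizes_ptr};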
template <class ProblemShape_>
class ArrayProblemShape {
public:
using UnderlyingProblemShape = ProblemShape_;
ArrayProblemShape() = default;
ArrayProblemShape(UnderlyingProblemShape ps) : problem_shape_(ps) {}
// The number of groups for Ptr-Array GEMM always remains one; only the number of batches (L) can vary
// This is just to maintain uniformity with GroupProblemShape
constexpr int32_t groups() const { return 1; }
UnderlyingProblemShape const* problem_shapes() const {
return &problem_shape_;
}
UnderlyingProblemShape const* host_problem_shapes() const {
return &problem_shape_;
}
// This is just to maintain uniformity with GroupProblemShape
CUTLASS_HOST_DEVICE
UnderlyingProblemShape const
get_problem_shape(int32_t /* unused */ = 0) const {
return problem_shape_;
}
CUTLASS_HOST_DEVICE
UnderlyingProblemShape const
get_host_problem_shape(int32_t /* unused */ = 0) const {
return problem_shape_;
}
private:
UnderlyingProblemShape problem_shape_{};
};
} // namespace cutlass::gemm

View File

@ -851,7 +851,9 @@ protected:
}
ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D;
if (ptr_Tensor) {
ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor;
ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
ptr_Tensor,
tile_work.tiled_coord.k() * params.batch_stride_Tensor);
}
if (ptr_Vector) {
ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector;
@ -2024,7 +2026,9 @@ protected:
ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C;
ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D;
if (ptr_Tensor) {
ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor;
ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
ptr_Tensor,
tile_work.tiled_coord.k() * params.batch_stride_Tensor);
}
if (ptr_Vector) {
ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector;

View File

@ -67,9 +67,9 @@ class GemmUniversal<
Epilogue_,
ThreadblockSwizzle_,
void,
// 3.x kernels use the first template argument to define the ProblemShape tuple
// 3.x kernels use the first template argument to define the ProblemShape
// We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API
cute::enable_if_t<not cute::is_tuple<Mma_>::value>
cute::enable_if_t<not (cute::is_tuple<Mma_>::value || IsCutlass3ArrayKernel<Mma_>::value)>
> {
public:

View File

@ -61,6 +61,19 @@ template <
>
class GemmUniversal;
////////////////////////////////////////////////////////////////////////////////
// In cases where ProblemShape is not a tuple, this trait checks whether the type defines a
// nested UnderlyingProblemShape alias.
// Used for dispatching GemmUniversal to the 2.x API or the 3.x API
template <class ProblemShape, class = void>
struct IsCutlass3ArrayKernel : cute::false_type { };
template <typename ProblemShape>
struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::UnderlyingProblemShape>>
: cute::true_type { };
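// For example, IsCutlass3ArrayKernel<GroupProblemShape<Shape<int,int,int>>>::value is true,
// whereas a plain rank-4 shape tuple yields false.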
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
@ -75,4 +88,5 @@ class GemmUniversal;
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp"
////////////////////////////////////////////////////////////////////////////////

View File

@ -42,7 +42,7 @@
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
#include "cutlass/subbyte_reference.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -676,7 +676,9 @@ public:
}
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
if (ptr_Tensor) {
ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
ptr_Tensor,
threadblock_tile_offset.k() * params.batch_stride_Tensor);
}
if (ptr_Vector) {
ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector;
@ -1387,7 +1389,9 @@ public:
ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
if (ptr_Tensor) {
ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
ptr_Tensor,
threadblock_tile_offset.k() * params.batch_stride_Tensor);
}
if (ptr_Vector) {
ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector;

View File

@ -217,6 +217,8 @@ public:
// Only used by device-level operator
GemmCoord *host_problem_sizes;
bool allow_early_exit;
//
// Methods
//
@ -235,7 +237,8 @@ public:
ldb(nullptr),
ldc(nullptr),
ldd(nullptr),
host_problem_sizes(nullptr)
host_problem_sizes(nullptr),
allow_early_exit(false)
{
}
@ -256,7 +259,8 @@ public:
typename LayoutB::Stride::LongIndex *ldb,
typename LayoutC::Stride::LongIndex *ldc,
typename LayoutC::Stride::LongIndex *ldd,
GemmCoord *host_problem_sizes=nullptr
GemmCoord *host_problem_sizes=nullptr,
bool allow_early_exit=false
):
mode(mode),
problem_sizes(problem_sizes),
@ -271,7 +275,8 @@ public:
ldb(ldb),
ldc(ldc),
ldd(ldd),
host_problem_sizes(host_problem_sizes)
host_problem_sizes(host_problem_sizes),
allow_early_exit(allow_early_exit)
{
}
@ -303,6 +308,7 @@ public:
typename LayoutC::Stride::LongIndex *ldc;
typename LayoutC::Stride::LongIndex *ldd;
bool allow_early_exit;
//
// Methods
@ -318,7 +324,8 @@ public:
lda(nullptr),
ldb(nullptr),
ldc(nullptr),
ldd(nullptr)
ldd(nullptr),
allow_early_exit(false)
{ }
CUTLASS_HOST_DEVICE
@ -333,7 +340,8 @@ public:
lda(args.lda),
ldb(args.ldb),
ldc(args.ldc),
ldd(args.ldd)
ldd(args.ldd),
allow_early_exit(args.allow_early_exit)
{
}
@ -388,6 +396,12 @@ public:
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Early exit (LAPACK-style quick return): alpha == 0 and beta == 1 leave the output unchanged
if (params.allow_early_exit &&
(params.output_op.alpha == ElementC(0)) && (params.output_op.beta == ElementC(1))) {
return;
}
//
// Problem visitor.
//

Some files were not shown because too many files have changed in this diff.