CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0
* Update CHANGELOG.md
---------
Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
@@ -1,4 +1,12 @@
# NVIDIA CUTLASS Changelog

## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4) (2023-12-29)
* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm).
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* NamedBarriers usability improvements and an officially released list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h).
* Improved [CuTe TMA Tensor](/media/docs/cute/0z_tma_tensors.md) documentation.

## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31)
* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
@@ -40,7 +40,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")

project(CUTLASS VERSION 3.3.0 LANGUAGES CXX)
project(CUTLASS VERSION 3.4.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 11.3)
@@ -681,6 +681,12 @@ endif()

################################################################################

set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default" CACHE STRING "Default
activated test sets. In `make test` mode, this string determines the
active set of tests. In `ctest` mode, this value can be overridden
with the CUTLASS_TEST_SETS environment variable when running the ctest
executable.")

set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake)
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")
@@ -701,11 +707,12 @@ function(cutlass_add_executable_tests NAME TARGET)
# generating the full variable name to be referenced.
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
# test results to speed up test runtime.
# TEST_SETS_SUPPORTED: A list of test set names these tests support.
#

set(options DISABLE_EXECUTABLE_INSTALL_RULE)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

if (NOT DEFINED __DISABLE_TESTS)
@@ -715,6 +722,12 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})

if (NOT DEFINED __TEST_SETS_SUPPORTED)
set(__TEST_SETS_SUPPORTED ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})
endif()

set(TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED})

if (__RESULT_CACHE_FILE)

add_custom_command(
@@ -816,8 +829,6 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
file(MAKE_DIRECTORY ${TEST_GEN_DIR})

set(TEST_SETS_SUPPORTED default)

set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT ON)
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
@@ -873,9 +884,9 @@ if (CUTLASS_INSTALL_TESTS)
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")

file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n")

file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SET})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SET} \"default\")\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "cmake_policy(SET CMP0057 NEW) # Allow IN_LIST for if()\n\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SETS} ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n")

foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
@@ -897,9 +908,15 @@ write_basic_package_version_file(
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
COMPATIBILITY AnyNewerVersion)

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
@ONLY
)

install(
FILES
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/
)
@@ -13,6 +13,7 @@ Cris Cecka<br />
Aniket Shivam<br />
Jack Kosaian<br />
Mark Hoemmen<br />
Richard Cai<br />
Honghao Lu<br />
Ethan Yan<br />
Haicheng Wu<br />
@@ -21,6 +22,8 @@ Dustyn Blasig<br />
Fengqi Qiao<br />
Duane Merrill<br />
Yujia Zhai<br />
Rawn Henry<br />
Sergey Klevtsov<br />
Shang Zhang<br />
Piotr Majcher<br />
Paul Springer<br />
@@ -55,6 +58,7 @@ Alan Kaatz<br />
Tina Li<br />
Timmy Liu<br />
Wei Liu<br />
Tim Martin<br />
Duane Merrill<br />
Kevin Siu<br />
Markus Tavenrath<br />
@@ -248,11 +248,15 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
endif()

if (NOT DEFINED __BATCH_SOURCES)
set(__BATCH_SOURCES ON)
endif()

if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
endif()

if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1)

set(CUDA_FILE_ARGS)
set(TARGET_SOURCE_ARGS)
README.md
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 3.3
# CUTLASS 3.4

_CUTLASS 3.3 - October 2023_
_CUTLASS 3.4 - December 2023_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@@ -41,17 +41,14 @@ and improves code composability and readability. More documentation specific to

In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.

# What's New in CUTLASS 3.3
# What's New in CUTLASS 3.4

CUTLASS 3.3.0 is an update to CUTLASS adding:
CUTLASS 3.4.0 is an update to CUTLASS adding:

- New [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input types with optimal performance.
- New [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8} and upcast on operandA {s8, u8} x {fp16, bf16}. They also include fast numeric conversion recipes and warp level shuffles to achieve optimal performance.
- New [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors (across s8/fp8/fp16/bf16/tf32 types) with optimal performance. As a part of this, new kernel schedules, and Copy Ops [SM80\_CP\_ASYNC\_CACHE\_\*](/include/cute/arch/copy_sm80.hpp) were also added.
- EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
- Various subbyte enhancements like tagged device ptrs, support for vectorized copy, various operators to treat subbyte iterators as pointers, and full-fledged CuTe Tensor support.
- Support for Clang as a host compiler.
- Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm), commonly used in the optimization of Mixture-of-Experts models, now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.

Minimum requirements:
@@ -95,7 +92,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA

CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1.
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, and CUDA 12.3.1.

## Operating Systems
We have tested the following environments.
@@ -107,6 +104,7 @@ We have tested the following environments.
| Ubuntu 22.04 | GCC 11.2.0 |
| Ubuntu 22.04 | Clang 10.0.0 |
| Ubuntu 22.04 | Clang 14.0.6 |
| Ubuntu 22.04 | Clang 17.0.6 |
| Windows 10.0 | Visual Studio 2019 v16.11.27 |

Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
@@ -2,10 +2,16 @@

set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)

#? if (DEFINED ENV{CUTLASS_TEST_SET} AND NOT ENV{CUTLASS_TEST_SET} IN_LIST TEST_SETS_SUPPORTED)
#?   message(STATUS "Skipping tests for @TEST_EXE_PATH@ as $ENV{CUTLASS_TEST_SET} is not in the set of ${TEST_SETS_SUPPORTED}.")
#?   return()
#? endif()
if (NOT DEFINED ENV{CUTLASS_TEST_SETS})
set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@)
endif()

foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS})
if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED)
message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].")
return()
endif()
endforeach()

set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
@@ -7,6 +7,3 @@ if(TARGET nvidia::cutlass::CUTLASS)
endif()

include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")

# For backward compatibility with the old name
add_library(cutlass_lib ALIAS cutlass_library)
@@ -291,8 +291,8 @@ int run() {
LayoutInputB,
ElementOutput,
LayoutOutput,
int32_t,
int32_t>
ElementComputeEpilogue,
ElementComputeEpilogue>
gemm_device;

// Launch device reference gemm kernel
@@ -279,7 +279,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {

out << "28_ampere_gemm_bias_fusion example\n\n"
out << "23_ampere_operand_gemm_reduction_fusion\n\n"
<< "Options:\n\n"
<< "  --help            If specified, displays this usage statement.\n\n"
<< "  --m=<int>            GEMM M\n"
@@ -297,7 +297,7 @@ struct Options {
<< "  --tag=<string>       String to replicate across the first column in the results table\n";

out << "\n\nExamples:\n\n"
<< "$ ./examples/23_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n";
<< "$ ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --m=1024 --n=1024 --k=1024 \n\n";

return out;
}
@@ -602,7 +602,7 @@ Result profile_convolution(Options const &options) {

std::stringstream ss;

ss << "26_ampere_fused_wgrad_batch_normalization_"
ss << "30_wgrad_split_k_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
@@ -251,7 +251,7 @@ struct Options {
<< "  --tag=<string>       String to replicate across the first column in the results table\n";

out << "\n\nExamples:\n\n"
<< "$ ./examples/31_transposed_conv2d/31_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n";
<< "$ ./examples/34_transposed_conv2d/34_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n";

return out;
}
@@ -398,7 +398,7 @@ struct Options {
<< "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n"

<< "# Execute Grouped SYR2K and profile with NSight\n"
<< "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";
<< "$ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";

return out;
}
@@ -306,7 +306,7 @@ struct Options {

/// Prints the usage statement.
std::ostream &print_usage(std::ostream &out) const {
out << "41_depthwise_gemm_fprop example\n\n"
out << "46_depthwise_gemm_fprop example\n\n"
<< "  This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< "  forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
@@ -554,7 +554,7 @@ Result profile_convolution(Options const &options) {
if (options.save_workspace) {
std::stringstream ss;

ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
ss << "46_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
<< "x" << options.input_size.w() << "x" << options.input_size.c() << "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x"
<< options.filter_size.w() << "x" << options.filter_size.c() << ".dat";
@@ -291,7 +291,7 @@ struct ExampleRunner {
using CustomEVT =  // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
@@ -302,7 +302,7 @@ struct ExampleRunner {
// Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp
// to the CollectiveBuilder. This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts.
// These tags also provide additional metadata that can be queried at compile time.
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementScalar, RoundStyle>;
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
@@ -568,7 +568,7 @@ struct ExampleRunner
if (options.reference_check) {
if (!verify()) {
std::cout << "Failed validation" << std::endl;
#if 1
#if 0
debug_output(std::cout);
#endif
return false;
@@ -122,7 +122,7 @@ public:
static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same.");

static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups;
static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance.");
@@ -109,6 +109,8 @@
namespace example
{

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

struct Options {

bool help;
@@ -724,6 +726,7 @@ private:
return true;
}
};
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

} // namespace example
@@ -749,7 +752,7 @@ int main(int argc, char const **argv)
if (notSupported) {
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
}

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
example::Options options;
options.parse(argc, argv);
@@ -970,6 +973,6 @@ int main(int argc, char const **argv)
result &= runner.run(options);
}
#endif

return result ? EXIT_SUCCESS : EXIT_FAILURE;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
}
@@ -109,9 +109,11 @@ using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;

// Auxiliary matrix configuration
// Auxiliary matrix configuration and other fusion types
using ElementAux = ElementC;
using LayoutAux = LayoutC;
using ElementAmax = float;
using ElementBias = float;

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
@@ -124,7 +126,7 @@ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux>;
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
@@ -38,19 +38,41 @@
using INT8 tensor cores.

The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
A and B in the main loop. Consequently, it is essential to consider this when constructing the epilogue, as illustrated in this example.
A and B in the main loop. However, implicit swaps do not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue,
as illustrated in this example.

Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.

It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).

Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.

If A is being scaled, the scales should have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k].

The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group size
equal to the gemm problem K (see the sketch following this comment block).

Limitations:
1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}.
2) The narrow type must always be in K-major format.
3) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format.
4) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure the narrow type passes through the register file, and TMA epilogues do not currently support swap + transpose operations.
We plan to address this limitation in the future.
3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
4) The scales and the zeros must have the same layout and group size.
5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format.
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands.

Examples:

Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0)
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0

$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2
Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire
matrix (group size is the same as the gemm k dimension).
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1

Runs the mixed input gemm, and applies a scaling factor and adds a zero-point to B before mma (mode 2). Uses a group size of 128.
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=128 --mode=2
*/
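To make the scale bookkeeping described in the comment above concrete, here is a minimal host-side C++ sketch. It is an editorial illustration only (not part of this diff or the CUTLASS API): the values are hypothetical, and the dequantization formula simply follows the `--mode` descriptions above.

```cpp
// Editorial sketch: scale_k and scale/zero shapes for the scaled-B case, plus the
// per-element dequantization implied by the mode descriptions above.
#include <cstdio>

int main() {
  int const k = 8192, n = 5120, l = 1;  // assumed problem extents
  int const g = 128;                    // group size along K
  int const scale_k = (k + g - 1) / g;  // ceil_div(problem_k, group_size)

  // B is the scaled operand, so scales/zeros have shape [N, scale_k] per batch,
  // MN-major (the fastest-changing dimension is N).
  std::printf("scale/zero shape per batch: [%d, %d], batches: %d\n", n, scale_k, l);

  // mode 1 dequantizes b as scale*b, mode 2 as scale*b + zero; an element at
  // (n_idx, k_idx) uses the group index k_idx / g for its scale and zero.
  auto dequant = [](float b, float scale, float zero, int mode) {
    return (mode == 2) ? scale * b + zero : (mode == 1) ? scale * b : b;
  };
  std::printf("dequant(3, 0.5, 1.0, mode=2) = %f\n", dequant(3.f, 0.5f, 1.f, 2));
  return 0;
}
```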
#include <iostream>
@@ -79,20 +101,28 @@
#include "cutlass/util/reference/host/gett.hpp"

#include "helper.h"
#include "unfused_weight_dequantize.h"
#include "unfused_weight_dequantize.hpp"

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

// This is just an example, so we use a regular enum so we can compare directly to the command-line int.
enum GemmMode {
ConvertOnly,
ScaleOnly,
ScaleWithZeroPoint
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::half_t;
using QuantType = int8_t;
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;

// A matrix configuration
using ElementA = MmaType; // Element type for A matrix operand
using ElementA = MmaType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
@@ -101,6 +131,14 @@ using ElementB = QuantType; // E
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;

using ElementZero = cutlass::half_t;
using ElementScale = cutlass::half_t;
using LayoutScale = cutlass::layout::RowMajor;

// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
@@ -109,51 +147,107 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_256,_64>; // Threadblock-level tile size
using TileShape = Shape<_128,_256,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueTileType,
ElementAccumulator, ElementAccumulator,
// Lie here about layout of C and D since we do swap and transpose trick
// Transpose layout of D here since we use explicit swap + transpose
// the void type for C tells the builder to allocate 0 smem for the C matrix.
// We can enable this if beta == 0 by changing ElementC to void below.
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
cutlass::epilogue::NoSmemWarpSpecialized // This is the only epi supporting the required swap + transpose.
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
// ============================================================ MIXED INPUT NO SCALES ============================================================================
// The collective will infer that the narrow type should be upcasted to the wide type.
// We swap A and B operands to the builder here
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementB, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;

using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveMainloopConvertOnly,
CollectiveEpilogue
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;

using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;

using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;

using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ==================================================================
// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format.
using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale, ElementZero>, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;

using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleWithZeroPoint,
CollectiveEpilogue
>;

using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleWithZeroPoint>;
// =================================================================================================================================================================

using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = typename GemmKernelScaleWithZeroPoint::StrideC;
using StrideD = typename GemmKernelScaleWithZeroPoint::StrideD;

using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;

//
// Data members
@@ -163,20 +257,25 @@ using StrideD = typename Gemm::GemmKernel::StrideD;
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideC_ref stride_C_ref;
StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;

// Initialization functions don't handle sub-byte types so we use uint8 to initialize and a separate
// kernel to pack the data if it is necessary.
using InitializationType = cute::conditional_t<cute::sizeof_bits_v<QuantType> < 8, uint8_t, QuantType>;
// Scale and Zero share a stride since the layout and shapes must be the same.
using StrideS = typename CollectiveMainloopScaleWithZeroPoint::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;

cutlass::HostTensor<typename Gemm::ElementA, LayoutA> tensor_A;
cutlass::HostTensor<InitializationType, LayoutB> tensor_B_init;
cutlass::HostTensor<typename Gemm::ElementB, LayoutB> tensor_B;
cutlass::HostTensor<typename Gemm::ElementA, LayoutB> tensor_B_dq;
cutlass::HostTensor<typename Gemm::ElementC, LayoutC> tensor_C;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, LayoutD> tensor_D;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, LayoutD> tensor_ref_D;
cutlass::HostTensor<MmaType, LayoutA> tensor_A;
cutlass::HostTensor<QuantType, LayoutB> tensor_B;
cutlass::HostTensor<MmaType, LayoutB> tensor_B_dq;
cutlass::HostTensor<ElementScale, LayoutScale> tensor_scale;
cutlass::HostTensor<ElementZero, LayoutScale> tensor_zero;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_D;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_ref_D;

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@@ -192,7 +291,9 @@ struct Options {
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 1000;
int mode = 2;
int m = 5120, n = 4096, k = 4096;
int g = 128;
int l = 1;

// Parses the command line
@@ -208,6 +309,8 @@ struct Options {
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("g", g);
cmd.get_cmd_line_argument("mode", mode);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
@@ -224,13 +327,15 @@ struct Options {
<< "  --n=<int>            Sets the N extent of the GEMM\n"
<< "  --k=<int>            Sets the K extent of the GEMM\n"
<< "  --l=<int>            The number of independent gemm problems with mnk shape\n"
<< "  --g=<int>            The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
<< "  --mode=<int>         The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
<< "  --alpha=<f32>        Epilogue scalar alpha\n"
<< "  --beta=<f32>         Epilogue scalar beta\n\n"
<< "  --iterations=<int>   Number of profiling iterations to perform.\n\n";

out
<< "\n\nExamples:\n\n"
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n";
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";

return out;
}
@@ -289,114 +394,146 @@ bool initialize_tensor(
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
view, seed, scope_max, scope_min);

return true;
}

template <class QuantElement, typename Element, typename Layout>
template <typename Element, typename Layout>
bool initialize_quant_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed=2023) {

Element scope_max, scope_min;
constexpr int bits_input = cute::sizeof_bits_v<QuantElement>;
static_assert(bits_input <= 8, "Quantization type can be at most 8 bits");

if constexpr (bits_input == 8) {
// Directly init 1-byte types
static_assert(cute::is_same_v<QuantElement, Element>, "Init type should equal quant type for 1 byte types");
scope_max = std::numeric_limits<QuantElement>::max();
scope_min = std::numeric_limits<QuantElement>::min();
} else {
static_assert(cute::is_same_v<uint8_t, Element>, "Init type should be uint8_t for sub-byte types");
scope_max = (1 << bits_input);
scope_min = 0;
}
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());

cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
view, seed, scope_max, scope_min);

return true;
}
template <class Element, class Layout>
bool initialize_with_one(
cutlass::TensorView<Element, Layout> view) {
cutlass::reference::host::TensorFill(view, Element(1.0f));
bool initialize_scale(
cutlass::TensorView<Element, Layout> view,
const Options &options) {

if (options.mode == GemmMode::ConvertOnly) {
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(1.0f));
}
else {
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
const float max_dequant_val = 4.f;
const float min_dequant_val = 0.5f;

float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);

cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
}
return true;
}

template <class ElementDst, class ElementSrc, class Layout, class L>
void prepare_packed_data(cutlass::HostTensor<ElementDst, Layout> view_dst_data,
cutlass::HostTensor<ElementSrc, Layout> view_src_data,
const L& cute_layout) {
if constexpr (cute::is_same_v<ElementSrc, ElementDst>) {
view_dst_data.copy_in_device_to_device(view_src_data.device_data());
}
else {
pack_data(view_dst_data.device_data(), view_src_data.device_data(), cute_layout);
template <class Element, class Layout>
bool initialize_zero(
cutlass::TensorView<Element, Layout> view,
const Options &options) {

if (options.mode == GemmMode::ScaleWithZeroPoint) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2.0f, -2.0f);
} else {
// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(0.0f));
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {

auto shape_b = cute::make_shape(options.n, options.k, options.l);
const int scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
// Reverse stride here due to swap and transpose
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));

auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);

tensor_A.resize(a_coord);
tensor_B_init.resize(b_coord);
tensor_B.resize(b_coord);
tensor_B_dq.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);

// We need scales since the "dequantize" kernels expects them. We just set them to 1 so the values get converted
// to the mma type.
cutlass::HostTensor<MmaType, cutlass::layout::RowMajor> tensor_scale;
tensor_scale.resize({1 * options.l, options.n});
tensor_scale.resize({scale_k * options.l, options.n});
tensor_zero.resize({scale_k * options.l, options.n});

initialize_tensor(tensor_A.host_view(), seed + 2022);
initialize_quant_tensor<QuantType>(tensor_B_init.host_view(), seed + 2021);
initialize_quant_tensor(tensor_B.host_view(), seed + 2021);
initialize_tensor(tensor_C.host_view(), seed + 2020);
initialize_with_one(tensor_scale.host_view());
initialize_scale(tensor_scale.host_view(), options);
initialize_zero(tensor_zero.host_view(), options);

tensor_A.sync_device();
tensor_B_init.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_scale.sync_device();
tensor_zero.sync_device();

auto layout_B = make_layout(shape_b, stride_B);
prepare_packed_data(tensor_B, tensor_B_init, layout_B);

auto shape_scale = cute::make_shape(options.n, 1, options.l);
auto layout_scale = make_layout(shape_scale);
dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), layout_scale);
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);

tensor_B.sync_host();
dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g);
tensor_B_dq.sync_host();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
template <typename Args>
Args args_from_options(const Options &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};

return arguments;
// Swap the A and B tensors, as well as problem shapes here.
if (options.mode == GemmMode::ConvertOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};
} else {
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
exit(-1);
}
}

bool verify(const Options &options) {
@@ -404,44 +541,64 @@ bool verify(const Options &options) {
// Compute reference output
//

// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B_dq.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
// In this example, we use the GPU default kernels as a reference (unfused scale)
// This is to avoid numerical differences from different accumulation order.

using unused_t = decltype(D);
// Again, due to numerical differences, we must use fast acc here when the mma type is
// FP8 as the fused implementation only supports fast acc at the moment.
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;

cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};

cutlass::reference::host::GettEpilogueParams<
typename Gemm::EpilogueOutputOp::ElementScalar,
typename Gemm::EpilogueOutputOp::ElementScalar,
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
unused_t, // aux
unused_t, // valpha
unused_t // vbeta
> epilogue_params;
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
ScheduleRef
>::CollectiveOp;

epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;

// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;

using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;

typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref}
};

// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());

// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
tensor_ref_D.sync_host();
const ElementD epsilon(1e-2f);
const ElementD non_zero_floor(1e-4f);
bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor);
return passed;
}
@@ -455,7 +612,7 @@ int run(Options &options)
Gemm gemm;

// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
auto arguments = args_from_options<typename Gemm::Arguments>(options);

// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
@@ -549,7 +706,26 @@ int main(int argc, char const **args) {
//

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
if (options.mode == GemmMode::ConvertOnly) {
std::cout << "Running in no scale mode." << std::endl;
run<GemmConvertOnly>(options);
}
else if (options.mode == GemmMode::ScaleOnly) {
if (options.g == options.k) {
std::cout << "Running in per-column scale mode." << std::endl;
} else {
std::cout << "Running in group scale mode." << std::endl;
}
run<GemmScaleOnly>(options);
}
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
if (options.g == options.k) {
std::cout << "Running in per-column scale and zero mode." << std::endl;
} else {
std::cout << "Running in group scale and zero mode." << std::endl;
}
run<GemmScaleWithZeroPoint>(options);
}
#endif

return 0;
@@ -27,9 +27,33 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.

set(TEST_DIRECT_BATCHED --m=2048 --n=2048 --k=2048 --l=2 --mode=0 --iterations=0)  # Direct conversion

set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 --iterations=0)  # Per Column scaling
set(TEST_SCALE_ZERO_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=2 --iterations=0)  # Per Column scaling

set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --g=512 --mode=1 --iterations=0)  # Group-wise scaling
set(TEST_SCALE_ZERO_GROUPED --m=2048 --n=5120 --k=8192 --g=256 --mode=2 --iterations=0)  # Group-wise scaling with zero-point

set(TEST_SCALE_RESIDUE --m=128 --n=128 --k=320 --g=128 --mode=1 --iterations=0)  # Final group has residue
set(TEST_SCALE_ZERO_RESIDUE --m=128 --n=128 --k=192 --g=128 --mode=2 --iterations=0)  # Final group has residue

set(TEST_ALPHA_BETA --alpha=0.5 --beta=0.7 --mode=2 --iterations=0)  # Alpha and Beta with default shapes

cutlass_example_add_executable(
55_hopper_mixed_dtype_gemm
55_hopper_mixed_dtype_gemm.cu
TEST_COMMAND_OPTIONS
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_ZERO_PERCOL
TEST_SCALE_GROUP
TEST_SCALE_ZERO_GROUPED
TEST_SCALE_RESIDUE
TEST_SCALE_ZERO_RESIDUE
TEST_ALPHA_BETA
)
@@ -9,17 +9,14 @@ This first version only supports mixed type GEMMs using TMA.

## Performance

While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x int8` for problems that are compute bound.
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.

We are currently optimizing the following cases:
1. Memory bound cases for all types
1. Compute bound cases for `{16-bit, 8-bit} x {4-bit, 2-bit}`

As a result, we do not suggest using this example as a benchmarking reference until all of our optimizations are complete (this will be clearly stated in this README in a future release).

## Limitations

* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen to control the layout of the epilogue as shown in the example. Note that TMA epilogues currently do not support swap + transpose, so non-tma epilogues must be used in this case. We plan to relax this limitation in a future release.
* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen. Note that TMA epilogues currently do not support *implicit* swap + transpose, so non-tma epilogues must be used in this case. We plan to relax this limitation in a future release.

* The layout of the narrow type must be K-major. This means the following:
  * Narrow type is the A operand: Must be Row-Major
@ -29,8 +26,12 @@ As a result, we do not suggest using this example as a benchmarking reference un
|
||||
|
||||
* TMA requires an alignment of 128 bits. As a result, for a type with `B` bits, `B x TILE_K` must be a multiple of 128 bits.
|
||||
|
||||
* The type of the scale and zero-point type must be two bytes or more.
|
||||
|
||||
* The group size must be equal to gemm-k size (indicating a broadcast), or it must be a multiple of the threadblock-k size.
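
For intuition, here is a minimal host-side reference of how a group size `g` selects the scale and zero-point along K. This helper (`dequantize_ref`), its template parameters, and the row-major layout it assumes are illustrative only and are not part of the example:

```c++
#include <cstddef>
#include <vector>

// Reference dequantization for a row-major [rows x k] narrow operand with
// per-group scales/zeros of shape [rows x (k / g)]. g == k gives per-column scaling.
template <class NarrowT, class WideT, class ScaleT>
std::vector<WideT> dequantize_ref(std::vector<NarrowT> const& q,
                                  std::vector<ScaleT> const& scale,
                                  std::vector<ScaleT> const& zero,
                                  int rows, int k, int g) {
  std::vector<WideT> dq(std::size_t(rows) * k);
  int groups = k / g;                                         // requires k % g == 0
  for (int r = 0; r < rows; ++r) {
    for (int kk = 0; kk < k; ++kk) {
      std::size_t sidx = std::size_t(r) * groups + kk / g;    // one (scale, zero) per K-group
      dq[std::size_t(r) * k + kk] =
          WideT(scale[sidx] * ScaleT(q[std::size_t(r) * k + kk]) + zero[sidx]);
    }
  }
  return dq;
}
```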

## Upcoming features

* Support for applying scales after conversion, but before issuing tensor core math (input scale fusion) is planned for v3.4.
* Optimizations for memory bound cases.

* Many optimizations for SOL performance.
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.

@ -9,13 +9,15 @@ template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleBroadCastLayout,
|
||||
class ThrLayout>
|
||||
__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
|
||||
const QuantizedElement* q_buffer,
|
||||
const OperandLayout operand_layout,
|
||||
const ElementScale* scale_buffer,
|
||||
const ScaleBroadCastLayout broadcasted_scale_layout,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleBroadCastLayout const broadcasted_scale_layout,
|
||||
ThrLayout thr_layout) {
|
||||
using namespace cute;
|
||||
|
||||
@ -33,6 +35,7 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
|
||||
  // The scales are expected to have shape [MN, G, L], with strides chosen to allow broadcasting along K.
  // It is expected that K % G == 0
  Tensor gmem_scale_broadcasted = make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
  Tensor gmem_zero_broadcasted = make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
|
||||
// Assign 1 thread per element in the thread block
|
||||
auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); //
|
||||
@ -41,16 +44,21 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
|
||||
// Tile across the block
|
||||
auto gOp_dq = local_tile(gmem_op_dq, blk_shape, blk_coord);
|
||||
auto gScale = local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
|
||||
auto gZero = local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
|
||||
auto gOp_q = local_tile(gmem_op_q, blk_shape, blk_coord);
|
||||
|
||||
auto tOpDq_gOpDq = local_partition(gOp_dq, thr_layout, threadIdx.x);
|
||||
auto tScale_gScale = local_partition(gScale, thr_layout, threadIdx.x);
|
||||
auto tZero_gZero = local_partition(gZero, thr_layout, threadIdx.x);
|
||||
auto tOpQ_gOpQ = local_partition(gOp_q, thr_layout, threadIdx.x);
|
||||
|
||||
// Make a fragment of registers to hold gmem loads
|
||||
Tensor rmem_op_q = make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
|
||||
Tensor rmem_scale = make_fragment_like(tScale_gScale(_, _, _, 0));
|
||||
Tensor rmem_zero = make_fragment_like(tZero_gZero(_, _, _, 0));
|
||||
Tensor rmem_op_dq = make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
|
||||
Tensor rmem_op_scaled = make_fragment_like<ElementScale>(rmem_op_dq);
|
||||
Tensor rmem_zero_buf = make_fragment_like<ElementScale>(rmem_zero);
|
||||
|
||||
Tensor pred_id = make_identity_tensor(shape(operand_layout));
|
||||
auto pred_blk_tile = local_tile(pred_id, blk_shape, blk_coord);
|
||||
@ -63,8 +71,12 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
    if (thread_offset < size<0>(operand_layout)) {
      copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
      copy(tScale_gScale(_, _, _, ii), rmem_scale);
      transform(rmem_op_q, rmem_op_dq, [] (const QuantizedElement& elt) { return DequantizedElement(elt); } );
      transform(rmem_op_dq, rmem_scale, rmem_op_dq, multiplies{});
      copy(tZero_gZero(_, _, _, ii), rmem_zero);
      transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
      transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
      transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{});
      transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{});
      transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
      copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
    }
  }
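  // Net effect per element: dq = DequantizedElement(ElementScale(q) * scale + ElementScale(zero)),
  // where one (scale, zero) pair is shared by every element of a K-group via the broadcasted layouts.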
@ -74,12 +86,15 @@ template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleLayout>
|
||||
void dequantize_weight(DequantizedElement* dq_buffer,
|
||||
const QuantizedElement* q_buffer,
|
||||
const OperandLayout operand_layout,
|
||||
const ElementScale* scale_buffer,
|
||||
const ScaleLayout scale_layout) {
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleLayout const scale_layout,
|
||||
int const group_size) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -87,9 +102,9 @@ void dequantize_weight(DequantizedElement* dq_buffer,
|
||||
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
|
||||
|
||||
const auto num_rows = get<0>(shape(operand_layout));
|
||||
const auto num_cols = get<1>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto num_cols_scale = get<1>(shape(scale_layout)); // [MN, G, L]
|
||||
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
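  // Note: scale_k == gemm_k / group_size, i.e. one scale (and zero-point) entry per K-group.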
|
||||
|
||||
if (num_rows != size<0>(scale_layout)) {
|
||||
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
|
||||
@ -98,81 +113,18 @@ void dequantize_weight(DequantizedElement* dq_buffer,
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (num_cols % num_cols_scale != 0) {
|
||||
std::cerr << "Invalid shape for weight / scales. Weight cols must be a multiple of scale cols."
|
||||
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
  const auto scale_stride0 = get<0>(stride(scale_layout));
  const auto scale_stride1 = get<1>(stride(scale_layout));
  const auto scale_stride2 = get<2>(stride(scale_layout));

  auto scale_shape_bcast = make_shape(num_rows, make_shape(num_cols / num_cols_scale, num_cols_scale), batches);
  auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
  auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
  auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
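  // The inner mode (group_size, scale_k) pairs stride 0 with scale_stride1, so all group_size
  // elements of a K-group read the same scale/zero entry (a stride-0 broadcast).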
|
||||
const auto blocks_x = num_cols;
|
||||
const auto blocks_x = gemm_k;
|
||||
const auto blocks_y = batches;
|
||||
|
||||
dim3 blocks(blocks_x, blocks_y, 1);
|
||||
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, scale_layout_bcast, thr_layout);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
|
||||
template <int ELTS_PER_THREAD,
|
||||
class SubbyteType>
|
||||
__global__ void pack_data_kernel(SubbyteType* packed_data_ptr,
|
||||
const uint8_t* unpacked_data_ptr,
|
||||
const size_t max_elts) {
|
||||
using namespace cute;
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint8_t data[ELTS_PER_THREAD];
|
||||
if (tid < max_elts) {
|
||||
const uint8_t* read_ptr = unpacked_data_ptr + tid * ELTS_PER_THREAD;
|
||||
for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) {
|
||||
data[ii] = read_ptr[ii];
|
||||
}
|
||||
|
||||
|
||||
using WriteType = cute::array_subbyte<SubbyteType, ELTS_PER_THREAD>;
|
||||
WriteType* write_ptr = reinterpret_cast<WriteType*>(packed_data_ptr);
|
||||
|
||||
WriteType packed_data;
|
||||
for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) {
|
||||
SubbyteType elt(data[ii]);
|
||||
packed_data[ii] = elt;
|
||||
}
|
||||
write_ptr[tid] = packed_data;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <class SubbyteType,
|
||||
class OperandLayout>
|
||||
void pack_data(SubbyteType* packed_data, const uint8_t* unpacked_data, const OperandLayout operand_layout) {
|
||||
static_assert(cute::sizeof_bits_v<SubbyteType> < 8, "First operand must be a sub-byte type");
|
||||
constexpr int packed_elements = 8 / cute::sizeof_bits_v<SubbyteType>;
|
||||
|
||||
if (cute::stride<0>(operand_layout) == 1 && (cute::shape<0>(operand_layout) % packed_elements)) {
|
||||
std::cerr << "Invalid shape / stride for dimension 0. Contiguous dimension must be a multiple of "
|
||||
<< packed_elements << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (cute::stride<1>(operand_layout) == 1 && (cute::shape<1>(operand_layout) % packed_elements)) {
|
||||
std::cerr << "Invalid shape / stride for dimension 1. Contiguous dimension must be a multiple of "
|
||||
<< packed_elements << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const int64_t total_threads = cute::size(operand_layout) / packed_elements;
|
||||
|
||||
const int threads_per_block = 256;
|
||||
const int64_t num_blocks = (total_threads + threads_per_block - 1) / threads_per_block;
|
||||
pack_data_kernel<packed_elements><<<int(num_blocks), threads_per_block>>>(packed_data, unpacked_data, total_threads);
|
||||
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
@ -0,0 +1,520 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
    \brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.

    This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA
    warp-specialized cooperative kernel.
    The new feature showcased in this example is on-the-fly modification of TMA descriptors
    to move between batches (represented by l).

    To run this example:

      $ ./examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm --m=2048 --n=2048 --k=2048 --l=10
*/
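
// Illustrative sketch only (it mirrors initialize() and args_from_options() below, and is not a
// separate API): in GemmUniversalMode::kArray every batch shares one problem shape and one set of
// strides, and the kernel receives device arrays of per-batch base pointers, e.g.
//
//   std::vector<ElementA *> ptr_A_host(l);
//   for (int32_t i = 0; i < l; ++i) {
//     ptr_A_host.at(i) = block_A.get() + offset_A.at(i);  // batch i's A operand
//   }
//   ptr_A.reset(l);
//   ptr_A.copy_from_host(ptr_A_host.data());              // device array of A pointers
//
// The mainloop arguments are then {ptr_A.get(), stride_A, ptr_B.get(), stride_B}.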
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int iterations = 10;
|
||||
int m = 1024, n = 512, k = 1024, l = 10;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "56_hopper_ptr_array_batched_gemm\n\n"
|
||||
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the batch count for Ptr-Array GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "56_hopper_ptr_array_batched_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(const Options &options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.l; ++i) {
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
|
||||
int64_t elements_A = options.m * options.k;
|
||||
int64_t elements_B = options.k * options.n;
|
||||
int64_t elements_C = options.m * options.n;
|
||||
int64_t elements_D = options.m * options.n;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
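  // All batches share one problem shape and one set of strides; only the per-batch base pointers,
  // assigned below, differ.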
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.l);
|
||||
std::vector<ElementB *> ptr_B_host(options.l);
|
||||
std::vector<ElementC *> ptr_C_host(options.l);
|
||||
std::vector<ElementC *> ptr_D_host(options.l);
|
||||
|
||||
for (int32_t i = 0; i < options.l; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
}
|
||||
|
||||
ptr_A.reset(options.l);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.l);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.l);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.l);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kArray,
|
||||
{{options.m, options.n, options.k, options.l}},
|
||||
{ptr_A.get(), stride_A, ptr_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, ptr_C.get(), stride_C, ptr_D.get(), stride_D},
|
||||
hw_info
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
bool passed = true;
|
||||
for (int32_t i = 0; i < options.l; ++i) {
|
||||
cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), options.m * options.n);
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Batches : " << options.l << std::endl;
|
||||
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
  // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
    std::cerr << "This example requires CUDA 12.3 or newer.\n";
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (props.major < 9) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
    return 0;
  }
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt (new file, 52 lines)
@ -0,0 +1,52 @@
|
||||
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||

# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.

set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=0)                # Square problem sizes
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=0)   # Square problem sizes

set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0)                        # Default problem sizes
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=0)     # Default problem sizes

set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=0)                 # Small-k problem sizes
set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=0)     # Small-k problem sizes

cutlass_example_add_executable(
  56_hopper_ptr_array_batched_gemm
  56_hopper_ptr_array_batched_gemm.cu
  TEST_COMMAND_OPTIONS
  TEST_SQUARE
  TEST_SQUARE_LARGE_BATCH
  TEST_EPILOGUE
  TEST_EPILOGUE_LARGE_BATCH
  TEST_SMALLK
  TEST_SMALLK_LARGE_BATCH
)
examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu (new file, 677 lines)
@ -0,0 +1,677 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
    \brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.

    This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
    warp-specialized cooperative kernel.
    For this example all scheduling work is performed on the device.
    The new feature showcased in this example is on-the-fly modification of TMA descriptors
    to move between groups/problem_count (represented by groups).

    To run this example:

      $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10

    The above command sizes all 10 groups at the given m, n, k. Skipping any of the problem
    dimensions randomizes that dimension across the different groups.

    To run this example for a set of problems using the benchmark option:

      $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --benchmark=./test_benchmark.txt

    where test_benchmark.txt might look like this:

      0 256x512x128
      1 256x512x512
      2 512x256x128
      3 256x256x128
      4 256x512x1024
      5 1024x512x128

    and so on.
*/
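
// For reference: each line of the benchmark file is "<index> <MxNxK>". benchmark_problems() below
// parses these entries and rounds every extent up to `alignment` elements, where alignment =
// tma_alignment_bits / sizeof_bits<ElementA> = 128 / 8 = 16 for the FP8 A operand used here.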
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedGroup; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int iterations = 10;
|
||||
int m = 1024, n = 2048, k = 512, groups = 10;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1;
|
||||
int cmd_line_n = -1;
|
||||
int cmd_line_k = -1;
|
||||
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_host.reserve(groups);
|
||||
|
||||
for (int i = groups; i > 0; i--) {
|
||||
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
|
||||
if (m < 1) {
|
||||
m = ((rand() % 512) + 1);
|
||||
}
|
||||
|
||||
if (n < 1) {
|
||||
n = ((rand() % 512) + 1);
|
||||
}
|
||||
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "57_hopper_grouped_gemm\n\n"
|
||||
<< " Hopper FP8 Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size.\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "57_hopper_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = uint64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes_host) {
|
||||
fmas += cute::size(problem);
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(-2);
|
||||
} else {
|
||||
scope_max = static_cast<Element>(8);
|
||||
scope_min = static_cast<Element>(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(const Options &options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{{options.alpha, options.beta}, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
  bool verify(const Options &options) {
    bool passed = true;
    for (int32_t i = 0; i < options.groups; ++i) {
      auto problem = options.problem_sizes_host.at(i);
      auto M = get<0>(problem);
      auto N = get<1>(problem);
      auto K = get<2>(problem);
      cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K}));
      cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N}));
      cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N}));
      cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N}));

      //
      // Compute reference output
      //

      // Create instantiation for device reference gemm kernel
      DeviceGemmReference gemm_reference;

      // Launch device reference gemm kernel
      gemm_reference(
        {M, N, K},
        ElementAccumulator(options.alpha),
        ref_A,
        ref_B,
        ElementAccumulator(options.beta),
        ref_C,
        ref_D);

      // Wait for kernel to finish
      CUDA_CHECK(cudaDeviceSynchronize());

      // Check if output from CUTLASS kernel and reference kernel are equal or not
      passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N);
#if 0
      std::cout << "Group: " << i << " Status: " << passed << std::endl;
#endif
    }
    return passed;
  }

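  // Conceptually, the device reference used above computes, for each group,
  // D = alpha * (A x B) + beta * C on an M x N x K problem. A plain host-side equivalent
  // (illustrative only; the example itself relies on DeviceGemmReference and BlockCompareEqual)
  // would look like:
  //
  //   for (int m = 0; m < M; ++m) {
  //     for (int n = 0; n < N; ++n) {
  //       float acc = 0.f;
  //       for (int k = 0; k < K; ++k) {
  //         acc += float(ref_A.at({m, k})) * float(ref_B.at({k, n}));
  //       }
  //       ref_D.at({m, n}) = options.alpha * acc + options.beta * float(ref_C.at({m, n}));
  //     }
  //   }
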
  /// Execute a given example GEMM computation
  template <typename Gemm>
  int run(Options &options)
  {
    allocate(options);
    initialize(options);

    // Instantiate CUTLASS kernel depending on templates
    Gemm gemm;

    // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
    auto arguments = args_from_options(options);

    // Using the arguments, query for extra workspace required for matrix multiplication computation
    size_t workspace_size = Gemm::get_workspace_size(arguments);

    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // Check if the problem size is supported or not
    CUTLASS_CHECK(gemm.can_implement(arguments));

    // Initialize CUTLASS kernel with arguments and workspace pointer
    CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));

    // Correctness / Warmup iteration
    CUTLASS_CHECK(gemm.run());

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    Result result;
    result.passed = verify(options);

    std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

    if (!result.passed) {
      exit(-1);
    }

    // Run profiling loop
    if (options.iterations > 0)
    {
      GpuTimer timer;
      timer.start();
      for (int iter = 0; iter < options.iterations; ++iter) {
        CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
        CUTLASS_CHECK(gemm.run());
      }
      timer.stop();

      // Compute average setup and runtime and GFLOPs.
      float elapsed_ms = timer.elapsed_millis();
      result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
      result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);

      std::cout << "  Problem Sizes: " << std::endl;
      for (auto const & problem : options.problem_sizes_host) {
        std::cout << "    " << problem << std::endl;
      }
      std::cout << "  Groups      : " << options.groups << std::endl;
      std::cout << "  Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
      std::cout << "  Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
      std::cout << "  GFLOPS      : " << result.gflops << std::endl;
    }

    return 0;
  }

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

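// The GFLOPS figure reported above is the grouped flop count divided by the averaged runtime.
// A sketch of what options.gflops() plausibly computes (the actual helper is defined with the
// Options struct earlier in this example and may differ in detail):
//
//   template <class ProblemSizes>
//   double gflops(double runtime_s, ProblemSizes const & problems) const {
//     uint64_t flop = 0;
//     for (auto const & p : problems) {
//       flop += 2ull * get<0>(p) * get<1>(p) * get<2>(p);   // 2 * M * N * K per group
//     }
//     return double(flop) / 1.0e9 / runtime_s;
//   }
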
int main(int argc, char const **args) {

  // CUTLASS must be compiled with the CUDA 12.3 Toolkit to run this example
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
    std::cerr << "This example requires CUDA 12.3 or newer.\n";
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));

  if (props.major < 9) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
    return 0;
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  //
  // Evaluate CUTLASS kernels
  //

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  run<Gemm>(options);
#endif

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

56 examples/57_hopper_grouped_gemm/CMakeLists.txt Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.

set(TEST_RANDOM --iterations=0)                                                           # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0)                                  # Random problem sizes

set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0)                                  # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0)         # Random problem sizes

set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0)                     # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0)          # Fixed problem sizes

set(TEST_RANDOM_PERF --iterations=10)                                                     # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10)                            # Random problem sizes

cutlass_example_add_executable(
  57_hopper_grouped_gemm
  57_hopper_grouped_gemm.cu
  TEST_COMMAND_OPTIONS
  TEST_RANDOM
  TEST_RANDOM_LARGE_GROUP
  TEST_EPILOGUE
  TEST_EPILOGUE_LARGE_GROUP
  TEST_FIXED
  TEST_FIXED_LARGE_GROUP
  TEST_RANDOM_PERF
  TEST_RANDOM_PERF_LARGE_GROUP
  )
@ -37,9 +37,10 @@
|
||||
|
||||
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
|
||||
|
||||
project(cutlass_import_example VERSION 0.1 LANGUAGES CXX CUDA)
|
||||
project(cutlass_import_example VERSION 0.2 LANGUAGES CXX CUDA)
|
||||
|
||||
if (DEFINED CUTLASS_DIR)
|
||||
if (CUTLASS_DIR)
|
||||
message(STATUS "Using CUTLASS specified at ${CUTLASS_DIR}.")
|
||||
list(APPEND CMAKE_PREFIX_PATH ${CUTLASS_DIR})
|
||||
endif()
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ function(cutlass_example_add_executable NAME)
|
||||
set(__DISABLE_TESTS OFF)
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS})
|
||||
cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS} BATCH_SOURCES OFF)
|
||||
|
||||
add_dependencies(cutlass_examples ${NAME})
|
||||
|
||||
@ -136,6 +136,8 @@ foreach(EXAMPLE
|
||||
53_hopper_gemm_permute
|
||||
54_hopper_fp8_warp_specialized_gemm
|
||||
55_hopper_mixed_dtype_gemm
|
||||
56_hopper_ptr_array_batched_gemm
|
||||
57_hopper_grouped_gemm
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
||||
@ -27,7 +27,14 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
sgemm_nt_1
|
||||
sgemm_nt_1.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
tiled_copy
|
||||
tiled_copy.cu
|
||||
)
|
||||
|
||||
|
||||
260 examples/cute/tutorial/tiled_copy.cu Normal file
@ -0,0 +1,260 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

// This is a simple tutorial showing several ways to partition a tensor into tiles and then
// perform efficient, coalesced copies. This example also shows how to vectorize accesses,
// which may be a useful optimization or required for certain workloads.
//
// `copy_kernel()` and `copy_kernel_vectorized()` each assume a pair of tensors with
// dimensions (m, n) have been partitioned via `tiled_divide()`.
//
// The results are a pair of compatible tensors with dimensions ((M, N), m', n'), where
// (M, N) denotes a statically sized tile, and m' and n' denote the number of such tiles
// within the tensor.
//
// Each statically sized tile is mapped to a CUDA threadblock which performs efficient
// loads and stores to Global Memory.
//
// `copy_kernel()` uses `cute::local_partition()` to partition the tensor and map
// the result to threads using a striped indexing scheme. Threads themselves are arranged
// in a (ThreadShape_M, ThreadShape_N) arrangement which is replicated over the tile.
//
// `copy_kernel_vectorized()` uses `cute::make_tiled_copy()` to perform a similar
// partitioning using `cute::Copy_Atom` to perform vectorization. The actual vector
// size is defined by `VecLayout`.
//
// This example assumes the overall tensor shape is divisible by the tile size and
// does not perform predication.

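// For the concrete sizes used in main() below: tiled_divide() of a (256, 512) tensor by a
// (_128, _64) tile yields a tensor of shape ((_128, _64), 2, 8), where the first mode is the
// statically sized tile and the trailing modes count the 256/128 = 2 by 512/64 = 8 tiles.
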
/// Simple copy kernel.
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
{
  using namespace cute;

  // Slice the tiled tensors
  Tensor tile_S = S(make_coord(_,_), blockIdx.x, blockIdx.y);   // (BlockShape_M, BlockShape_N)
  Tensor tile_D = D(make_coord(_,_), blockIdx.x, blockIdx.y);   // (BlockShape_M, BlockShape_N)

  // Construct a partitioning of the tile among threads with the given thread arrangement.

  // Concept: local_partition(Tensor, ThreadLayout, ThreadIndex)
  Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
  Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);

  // Construct a register-backed Tensor with the same shape as each thread's partition
  auto fragment = make_fragment_like(thr_tile_S);

  // Copy from GMEM to RMEM and from RMEM to GMEM
  copy(thr_tile_S, fragment);
  copy(fragment, thr_tile_D);
}

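// Roughly speaking, with the (_32, _8) ThreadLayout used in main() below, local_partition()
// hands thread t the tile elements whose coordinates satisfy (m % 32, n % 8) == (t % 32, t / 32)
// (assuming the default column-major thread numbering), i.e. a (128/32) x (64/8) = 4 x 8
// fragment per thread.
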
/// Vectorized copy kernel.
///
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
{
  using namespace cute;
  using Element = typename TensorS::value_type;

  // Slice the tensors to obtain a view into each tile.
  Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y);  // (BlockShape_M, BlockShape_N)
  Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y);  // (BlockShape_M, BlockShape_N)

  // Define `AccessType` which controls the size of the actual memory access.
  using AccessType = cutlass::AlignedArray<Element, size(shape(VecLayout{}))>;

  // A copy atom corresponds to one hardware memory access.
  using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;

  // Construct tiled copy, a tiling of copy atoms.
  //
  // Note, this assumes the vector and thread layouts are aligned with contiguous data
  // in GMEM. Alternative thread layouts are possible but may result in uncoalesced
  // reads. Alternative vector layouts are also possible, though incompatible layouts
  // will result in compile time errors.
  auto tiled_copy =
    make_tiled_copy(
      Atom{},                       // access size
      ThreadLayout{},               // thread layout
      VecLayout{});                 // vector layout (e.g. 4x1)

  // Construct a Tensor corresponding to each thread's slice.
  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);

  Tensor thr_tile_S = thr_copy.partition_S(tile_S);
  Tensor thr_tile_D = thr_copy.partition_D(tile_D);

  // Construct a register-backed Tensor with the same shape as each thread's partition
  auto fragment = make_fragment_like(thr_tile_D);

  // Copy from GMEM to RMEM and from RMEM to GMEM
  copy(tiled_copy, thr_tile_S, fragment);
  copy(tiled_copy, fragment, thr_tile_D);
}

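// With Element = float and the (_4, _1) VecLayout chosen in main() below, AccessType is a
// 16-byte cutlass::AlignedArray<float, 4>, so each UniversalCopy atom issues a single 128-bit
// load or store. This is why the kernel requires pointers aligned to the vector size.
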
/// Helper to convert a shape to a dim3
template <class Shape>
dim3 shape_to_dim3(Shape shape)
{
  using namespace cute;

  CUTE_STATIC_ASSERT_V(rank(shape) <= Int<3>{});
  auto result = append<3>(product_each(shape), 1u);

  return dim3(get<0>(result), get<1>(result), get<2>(result));
}

/// Main function
int main(int argc, char** argv)
{
  //
  // Given a 2D shape, perform an efficient copy
  //

  using namespace cute;
  using Element = float;

  // Define a tensor shape with dynamic extents (m, n)
  auto tensor_shape = make_shape(256, 512);

  thrust::host_vector<Element> h_S(size(tensor_shape));
  thrust::host_vector<Element> h_D(size(tensor_shape));

  //
  // Initialize
  //

  for (size_t i = 0; i < h_S.size(); ++i) {
    h_S[i] = static_cast<Element>(i);
    h_D[i] = Element{};
  }

  thrust::device_vector<Element> d_S = h_S;
  thrust::device_vector<Element> d_D = h_D;

  //
  // Make tensors
  //

  Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
  Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));

  //
  // Partition
  //

  // Define a statically sized block (M, N).
  //
  // Note, by convention, capital letters are used to represent static modes.
  auto block_shape = make_shape(Int<128>{}, Int<64>{});

  if ((get<0>(tensor_shape) % get<0>(block_shape)) || (get<1>(tensor_shape) % get<1>(block_shape))) {
    std::cerr << "The tensor shape must be divisible by the block shape." << std::endl;
    return -1;
  }

  // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
  // shape, and modes (m', n') correspond to the number of tiles.
  //
  // These will be used to determine the CUDA kernel grid dimensions.
  Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape);
  Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape);

  // Thread arrangement
  Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{}));

  // Vector dimensions
  Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));

  //
  // Determine grid and block dimensions
  //

  dim3 gridDim = shape_to_dim3(select<1,2>(shape(tiled_tensor_D)));  // Grid shape corresponds to modes m' and n'
  dim3 blockDim(size(shape(thr_layout)));

  //
  // Launch the kernel
  //
  copy_kernel_vectorized<<< gridDim, blockDim >>>(
    tiled_tensor_S,
    tiled_tensor_D,
    thr_layout,
    vec_layout);

  cudaError result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
    return -1;
  }

  //
  // Verify
  //

  h_D = d_D;

  int32_t errors = 0;
  int32_t const kErrorLimit = 10;

  for (size_t i = 0; i < h_D.size(); ++i) {
    if (h_S[i] != h_D[i]) {
      std::cerr << "Error. S[" << i << "]: " << h_S[i] << ", D[" << i << "]: " << h_D[i] << std::endl;

      if (++errors >= kErrorLimit) {
        std::cerr << "Aborting after " << kErrorLimit << " errors." << std::endl;
        return -1;
      }
    }
  }

  std::cout << "Success." << std::endl;

  return 0;
}

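For completeness: main() above only launches copy_kernel_vectorized. The plain copy_kernel defined earlier takes the same tiled tensors and thread layout, just without a vector layout, so a non-vectorized run could be sketched as follows (illustrative, reusing the names from main()):

  copy_kernel<<< gridDim, blockDim >>>(
    tiled_tensor_S,
    tiled_tensor_D,
    thr_layout);
  cudaDeviceSynchronize();

Both kernels move the same (256, 512) tensor; the vectorized variant simply amortizes the transfer over 128-bit accesses.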
@ -357,7 +357,7 @@
|
||||
"## Handling errors\n",
|
||||
"The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n",
|
||||
"\n",
|
||||
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user."
|
||||
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user. Uncomment and run the code below to see this error."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -371,6 +371,75 @@
|
||||
"# td.stages = 8\n",
|
||||
"# plan.compile(td)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Specializations for other data types\n",
|
||||
"\n",
|
||||
"Various CUTLASS kernels specialized for specific data types can also be run via the Python interface.\n",
|
||||
"\n",
|
||||
"For example, the code below shows how to declare and run a GEMM using the 3xTF32 feature (see corresponding C++ example [here](https://github.com/NVIDIA/cutlass/blob/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu))."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass.backend.utils.device import device_cc\n",
|
||||
"\n",
|
||||
"# 3xTF32 requires SM80 or higher\n",
|
||||
"if device_cc() >= 80:\n",
|
||||
" plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)\n",
|
||||
" plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32\n",
|
||||
"\n",
|
||||
" # Create input/output tensors in FP32\n",
|
||||
" A, B = [np.ones((128, 128)).astype(np.float32) for _ in range(2)]\n",
|
||||
" C, D = [np.zeros((128, 128)).astype(np.float32) for _ in range(2)]\n",
|
||||
"\n",
|
||||
" # Run the GEMM\n",
|
||||
" plan.run(A, B, C, D, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Additionally, one can run CUTLASS's FP8 GEMMs if using a frontend library capable of allocating and initializing FP8 tensors (e.g., PyTorch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"try:\n",
|
||||
" import torch\n",
|
||||
"except ImportError:\n",
|
||||
" print(\"PyTorch is not available. Skipping FP8 example\")\n",
|
||||
" import sys; sys.exit(0)\n",
|
||||
"\n",
|
||||
"if not hasattr(torch, \"float8_e4m3fn\"):\n",
|
||||
" print(\"Version of PyTorch does not have the float8_e4m3fn data type. Skipping FP8 example\")\n",
|
||||
" import sys; sys.exit(0)\n",
|
||||
"\n",
|
||||
"# FP8 is supported through the CUTLASS Python interface on SM90 and higher\n",
|
||||
"if device_cc() >= 90:\n",
|
||||
" plan = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n",
|
||||
" layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,\n",
|
||||
" layout_C=cutlass.LayoutType.ColumnMajor)\n",
|
||||
"\n",
|
||||
" # Create input/output tensors in FP8\n",
|
||||
" A, B = [torch.ones((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n",
|
||||
" C, D = [torch.zeros((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n",
|
||||
"\n",
|
||||
" # Run the GEMM\n",
|
||||
" plan.run(A, B, C, D, print_module=print_module)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@ -134,7 +134,7 @@
|
||||
"id": "590a3bc5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch."
|
||||
"We'll next run a group of 20 GEMMs via the CUTLASS Python interface and via PyTorch."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -144,7 +144,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"As, Bs, Cs, Ds, = generate_problems(50)\n",
|
||||
"As, Bs, Cs, Ds, = generate_problems(20)\n",
|
||||
"\n",
|
||||
"plan.run(As, Bs, Cs, Ds, print_module=True)\n",
|
||||
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
|
||||
|
||||
@ -95,6 +95,7 @@ CUTE_DEVICE dim3 cluster_grid_dims()
|
||||
return gridDim;
|
||||
#elif defined(_MSC_VER)
|
||||
CUTE_RUNTIME_ASSERT("cluster_grid_dims() can only be called on device");
|
||||
return {0, 0, 0};
|
||||
#else
|
||||
return {0, 0, 0};
|
||||
#endif
|
||||
@ -114,6 +115,7 @@ CUTE_DEVICE dim3 cluster_id_in_grid()
|
||||
return blockIdx;
|
||||
#elif defined(_MSC_VER)
|
||||
CUTE_RUNTIME_ASSERT("cluster_id_in_grid() can only be called on device");
|
||||
return {0, 0, 0};
|
||||
#else
|
||||
return {0, 0, 0};
|
||||
#endif
|
||||
|
||||
@ -40,6 +40,11 @@
|
||||
# define CUTE_ARCH_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED) && \
|
||||
((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3)))
|
||||
# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
|
||||
@ -42,6 +42,7 @@
|
||||
|
||||
#include <cute/container/alignment.hpp>
|
||||
#include <cute/container/bit_field.hpp>
|
||||
#include <cute/container/array.hpp>
|
||||
#include <cute/numeric/int.hpp> // to_Format<[u]intX>
|
||||
#include <cute/numeric/half.hpp> // to_Format<half_t>
|
||||
|
||||
@ -200,6 +201,141 @@ prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Perform a TensorMap modification (by each field)
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Replace tensor pointer directly in GMEM
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr,
|
||||
void const* const new_tensor_ptr)
|
||||
{
|
||||
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint64_t const new_desc_addr = reinterpret_cast<uint64_t>(new_tensor_ptr);
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;"
|
||||
:: "l"(gmem_int_desc), "l"(new_desc_addr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Replace tensor pointer by bringing the tensormap from GMEM into the shared memory
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
|
||||
void const* const new_tensor_ptr)
|
||||
{
|
||||
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
|
||||
uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
|
||||
uint64_t const new_desc_addr = reinterpret_cast<uint64_t>(new_tensor_ptr);
|
||||
uint64_t const smem_int64_desc = 0;
|
||||
asm volatile (
|
||||
"cvt.u64.u32 %0, %1;"
|
||||
:: "l"(smem_int64_desc), "r"(smem_int_desc));
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
|
||||
:: "l"(smem_int64_desc), "l"(new_desc_addr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Replace tensor dims and strides for GEMMs by bringing the tensormap from GMEM into the shared memory
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc,
|
||||
cute::array<uint32_t, 3> const& prob_shape,
|
||||
cute::array<uint64_t, 3> const& prob_stride)
|
||||
{
|
||||
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
|
||||
uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
|
||||
uint64_t const smem_int64_desc = 0;
|
||||
asm volatile (
|
||||
"cvt.u64.u32 %0, %1;"
|
||||
:: "l"(smem_int64_desc), "r"(smem_int_desc));
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;"
|
||||
:: "l"(smem_int64_desc), "r"(prob_shape[0]));
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;"
|
||||
:: "l"(smem_int64_desc), "r"(prob_shape[1]));
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
|
||||
:: "l"(smem_int64_desc), "r"(prob_shape[2]));
|
||||
// Strides must be a multiple of 16. Also, the stride for the innermost dimension is implicitly 1.
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;"
|
||||
:: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4));
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
|
||||
:: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory)
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc)
|
||||
{
|
||||
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
|
||||
uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
|
||||
asm volatile (
|
||||
"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"
|
||||
:: "l"(gmem_int_desc), "r"(smem_int_desc));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Perform a release fence operation (needed when modifying tensormap directly in GMEM)
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
tma_descriptor_fence_release()
|
||||
{
|
||||
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
|
||||
asm volatile ("fence.proxy.tensormap::generic.release.gpu;");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Perform an acquire fence operation
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr)
|
||||
{
|
||||
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
asm volatile (
|
||||
"fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
|
||||
:
|
||||
: "l"(gmem_int_desc)
|
||||
: "memory");
|
||||
asm volatile (
|
||||
"cvta.global.u64 %0, %0;"
|
||||
:
|
||||
: "l"(gmem_int_desc), "l"(gmem_int_desc)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
|
||||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
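Taken together, these helpers support the pattern the new grouped and pointer-array GEMM kernels rely on: patch a shared-memory copy of the tensormap, publish it back to global memory with a release fence, then acquire it before the next TMA that consumes it. A rough sketch of that sequence (illustrative only; names such as ptr_next_group, prob_shape, and prob_stride are placeholders, and the warp election and barrier synchronization the real kernels perform are omitted):

  // smem_desc: a TmaDescriptor staged in shared memory; gmem_desc_ptr: its home in GMEM
  cute::tma_descriptor_replace_addr_in_shared_mem(smem_desc, ptr_next_group);
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(smem_desc, prob_shape, prob_stride);
  cute::tma_descriptor_cp_fence_release(gmem_desc_ptr, smem_desc);  // copy back to GMEM + release fence
  cute::tma_descriptor_fence_acquire(gmem_desc_ptr);                // before issuing TMA with this descriptor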
@ -775,6 +775,104 @@ struct SM90_TMA_STORE
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_STORE im2col: Initiates a TMA copy, in im2col mode, from shared memory to global memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM90_TMA_STORE_IM2COL_3D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group"
|
||||
" [%0, {%2, %3, %4}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_n)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE_IM2COL_4D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group"
|
||||
" [%0, {%2, %3, %4, %5}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE_IM2COL_5D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group"
|
||||
" [%0, {%2, %3, %4, %5, %6}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE_IM2COL
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n)
|
||||
{
|
||||
return SM90_TMA_STORE_IM2COL_3D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_n);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n)
|
||||
{
|
||||
return SM90_TMA_STORE_IM2COL_4D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_n);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n)
|
||||
{
|
||||
return SM90_TMA_STORE_IM2COL_5D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n);
|
||||
}
|
||||
};
|
||||
|
||||
// Indicate arrival of warp issuing TMA_STORE
|
||||
CUTE_HOST_DEVICE static void
|
||||
tma_store_arrive() {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
/**************************************************************************************************
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
|
||||
@ -31,26 +31,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cute/atom/copy_traits.hpp>
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class... Args>
|
||||
struct Copy_Atom;
|
||||
|
||||
template <class CopyOperation, class T>
|
||||
struct Copy_Atom<CopyOperation, T> : Copy_Atom<Copy_Traits<CopyOperation>, T>
|
||||
template <class CopyOperation, class CopyInternalType>
|
||||
struct Copy_Atom<CopyOperation, CopyInternalType> : Copy_Atom<Copy_Traits<CopyOperation>, CopyInternalType>
|
||||
{};
|
||||
|
||||
template <class... Args, class T>
|
||||
struct Copy_Atom<Copy_Traits<Args...>, T>
|
||||
template <class... Args, class CopyInternalType>
|
||||
struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
|
||||
: Copy_Traits<Args...>
|
||||
{
|
||||
using Traits = Copy_Traits<Args...>;
|
||||
@ -61,7 +63,7 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
|
||||
using BitLayoutDst = typename Traits::DstLayout;
|
||||
using BitLayoutRef = typename Traits::RefLayout;
|
||||
|
||||
using ValType = T;
|
||||
using ValType = CopyInternalType;
|
||||
|
||||
using ValLayoutSrc = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutSrc{}));
|
||||
using ValLayoutDst = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutDst{}));
|
||||
@ -80,7 +82,7 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
|
||||
auto
|
||||
with(TraitsArgs&&... args) const {
|
||||
auto traits = Traits::with(std::forward<TraitsArgs>(args)...);
|
||||
return Copy_Atom<decltype(traits), T>{traits};
|
||||
return Copy_Atom<decltype(traits), CopyInternalType>{traits};
|
||||
}
|
||||
|
||||
//
|
||||
@ -88,19 +90,19 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
|
||||
//
|
||||
|
||||
// Check and call instruction, or recurse
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
template <class SEngine, class SLayout,
|
||||
class DEngine, class DLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
call(Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst) const
|
||||
call(Tensor<SEngine,SLayout> const& src,
|
||||
Tensor<DEngine,DLayout> & dst) const
|
||||
{
|
||||
static_assert(SLayout::rank == 1, "Expected rank-1 src tensor");
|
||||
static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor");
|
||||
|
||||
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value ||
|
||||
is_constant<NumValDst, decltype(size(dst))>::value) {
|
||||
// Dispatch to unpack for instruction
|
||||
// Dispatch to unpack to execute instruction
|
||||
return copy_unpack(*this, src, dst);
|
||||
} else
|
||||
if constexpr (is_tuple<decltype(shape(src))>::value &&
|
||||
@ -110,7 +112,7 @@ struct Copy_Atom<Copy_Traits<Args...>, T>
|
||||
// ((A,B,C,...)) -> (A,B,C,...)
|
||||
return copy(*this, tensor<0>(src), tensor<0>(dst));
|
||||
} else {
|
||||
static_assert(sizeof(TS) < 0, "No instruction match and no recursion possible.");
|
||||
static_assert(dependent_false<SEngine>, "No instruction match and no recursion possible.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,7 +137,7 @@ struct ThrCopy;
|
||||
|
||||
template <class Copy_Atom,
|
||||
class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
|
||||
class ShapeTile_MN> // coord space
|
||||
class ShapeTiler_MN> // coord space
|
||||
struct TiledCopy : Copy_Atom
|
||||
{
|
||||
// Layout information from the CopyAtom
|
||||
@ -148,8 +150,7 @@ struct TiledCopy : Copy_Atom
|
||||
using AtomNumVal = decltype(size<1>(AtomLayoutRef{}));
|
||||
|
||||
// Layout information for the TiledCopy
|
||||
using Tiler_MN = ShapeTile_MN;
|
||||
using TiledShape_MN = decltype(shape(ShapeTile_MN{}));
|
||||
using Tiler_MN = ShapeTiler_MN;
|
||||
using TiledLayout_TV = LayoutCopy_TV;
|
||||
using TiledNumThr = decltype(size<0>(TiledLayout_TV{}));
|
||||
using TiledNumVal = decltype(size<1>(TiledLayout_TV{}));
|
||||
@ -172,12 +173,9 @@ struct TiledCopy : Copy_Atom
|
||||
auto
|
||||
tidfrg_S(STensor&& stensor)
|
||||
{
|
||||
constexpr int R = remove_cvref_t<STensor>::rank;
|
||||
static_assert(R >= rank_v<TiledShape_MN>, "Rank of tensor to be partitioned too small.");
|
||||
// Generalize the dimension checks for arbitrary rank
|
||||
//CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small.");
|
||||
|
||||
// Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout
|
||||
return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}));
|
||||
}
|
||||
|
||||
@ -196,17 +194,14 @@ struct TiledCopy : Copy_Atom
|
||||
auto
|
||||
tidfrg_D(DTensor&& dtensor)
|
||||
{
|
||||
constexpr int R = remove_cvref_t<DTensor>::rank;
|
||||
static_assert(R >= rank_v<TiledShape_MN>, "Rank of tensor to be partitioned too small.");
|
||||
// Generalize the dimension checks for arbitrary rank
|
||||
//CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small.");
|
||||
|
||||
// Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout
|
||||
return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}));
|
||||
}
|
||||
|
||||
// Tile a tensor or a layout from shape
|
||||
// (Tile,(RestM,RestN,...))
|
||||
// ((TileM,TileN,...), (RestM,RestN,...))
|
||||
// to shape
|
||||
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
|
||||
template <class Tensor, class Ref2TrgLayout>
|
||||
@ -232,7 +227,7 @@ struct TiledCopy : Copy_Atom
|
||||
|
||||
// Transform the tile mode
|
||||
auto tv_tensor = tensor.compose(thrval2mn, _);
|
||||
// ((thrid,val),(RM,RN,...))
|
||||
// ((thrid,val),(RestM,RestN,...))
|
||||
|
||||
// Unfold and return
|
||||
return tv_tensor(make_coord(_,_), _);
|
||||
@ -253,7 +248,7 @@ struct TiledCopy : Copy_Atom
|
||||
|
||||
auto V = size<0>(tensor);
|
||||
|
||||
auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{}));
|
||||
auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{})));
|
||||
// (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV
|
||||
|
||||
auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{}));
|
||||
@ -278,7 +273,7 @@ struct TiledCopy : Copy_Atom
|
||||
get_layoutS_TV()
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_S = make_layout(make_shape(TiledShape_MN{}, Int<1>{}));
|
||||
auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{}));
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{});
|
||||
}
|
||||
@ -290,7 +285,7 @@ struct TiledCopy : Copy_Atom
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
auto layoutS_TV = get_layoutS_TV();
|
||||
// (M,K) -> (thr_idx,val_idx)
|
||||
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(TiledShape_MN{});
|
||||
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{}));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));
|
||||
@ -303,7 +298,7 @@ struct TiledCopy : Copy_Atom
|
||||
get_layoutD_TV()
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_D = make_layout(make_shape(TiledShape_MN{}, Int<1>{}));
|
||||
auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{}));
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{});
|
||||
}
|
||||
@ -315,7 +310,7 @@ struct TiledCopy : Copy_Atom
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
auto layoutD_TV = get_layoutD_TV();
|
||||
// (M,K) -> (thr_idx,val_idx)
|
||||
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(TiledShape_MN{});
|
||||
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{}));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));
|
||||
@ -406,51 +401,44 @@ make_tiled_copy_impl(Copy_Atom<Args...> const& atom,
|
||||
// These tile the Copy_Atom as a whole
|
||||
//
|
||||
|
||||
template <class... Args,
|
||||
class TiledMMA>
|
||||
template <class... CArgs, class... MArgs>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_A(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma)
|
||||
make_tiled_copy_A(Copy_Atom<CArgs...> const& copy_atom,
|
||||
TiledMMA<MArgs...> const& mma)
|
||||
{
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{})));
|
||||
return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma)));
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class TiledMMA>
|
||||
template <class... CArgs, class... MArgs>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_B(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma)
|
||||
make_tiled_copy_B(Copy_Atom<CArgs...> const& copy_atom,
|
||||
TiledMMA<MArgs...> const& mma)
|
||||
{
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{})));
|
||||
return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma)));
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class TiledMMA>
|
||||
template <class... CArgs, class... MArgs>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma)
|
||||
make_tiled_copy_C(Copy_Atom<CArgs...> const& copy_atom,
|
||||
TiledMMA<MArgs...> const& mma)
|
||||
{
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{})));
|
||||
return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma)));
|
||||
}
|
||||
|
||||
// returns the smallest tiled copy that can retile LayoutC_TV
|
||||
// for use with pipelined epilogues with subtiled stores
|
||||
template <class... Args,
|
||||
class TiledMMA>
|
||||
template <class... CArgs, class... MArgs>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C_atom(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma)
|
||||
make_tiled_copy_C_atom(Copy_Atom<CArgs...> const& copy_atom,
|
||||
TiledMMA<MArgs...> const& mma)
|
||||
{
|
||||
// Truncate the V-layout to just the Copy_Atom, keep the V-order
|
||||
auto layoutC_TV = tiled_mma.get_layoutC_TV();
|
||||
auto copy_V = Int<Copy_Atom<Args...>::NumValSrc>{};
|
||||
auto layoutC_TV = mma.get_layoutC_TV();
|
||||
auto copy_V = Int<Copy_Atom<CArgs...>::NumValSrc>{};
|
||||
CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV));
|
||||
auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V)));
|
||||
|
||||
@ -458,8 +446,7 @@ make_tiled_copy_C_atom(Copy_Atom<Args...> const& copy_atom,
|
||||
|
||||
// Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them
|
||||
// Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
auto mma_tiler = make_shape(size<0>(MNK{}),size<1>(MNK{}));
|
||||
auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma));
|
||||
auto mma_zeros = repeat_like(mma_tiler, Int<0>{});
|
||||
|
||||
auto tiler = transform(make_seq<rank(mma_tiler)>{}, [&](auto i) {
|
||||
@ -474,8 +461,6 @@ make_tiled_copy_C_atom(Copy_Atom<Args...> const& copy_atom,
|
||||
// (tid,vid) -> tile_coord
|
||||
auto layout_tv = composition(left_inverse(tile2mma), layout_TV);
|
||||
|
||||
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
return make_tiled_copy_impl(copy_atom, layout_tv, tiler);
|
||||
}
|
||||
|
||||
@ -655,8 +640,10 @@ print(TiledCopy<Atom, Args...> const& copy, char const* pad = "")
|
||||
template <class TiledCopy, class ThrIdx>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(ThrCopy<TiledCopy, ThrIdx> const&)
|
||||
print(ThrCopy<TiledCopy, ThrIdx> const& thr_copy)
|
||||
{
|
||||
print("ThrCopy\n");
|
||||
print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n");
|
||||
print(TiledCopy{});
|
||||
}
|
||||
|
||||
|
||||
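These make_tiled_copy_{A,B,C} overloads now accept any TiledMMA<MArgs...> and read its tile extents through tile_size<I>(mma) rather than the removed TiledShape_MNK alias. A minimal usage sketch (the MMA and copy atoms below are placeholders chosen for illustration, not part of this change):

  auto tiled_mma = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{});
  auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, half_t>{}, tiled_mma);
  auto thr_copy_A = smem_tiled_copy_A.get_thread_slice(threadIdx.x);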
@ -43,10 +43,14 @@
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class GmemStrides_, class TmaGBasis_, class TmaSwizzle_>
|
||||
template <class GmemTmaBasisStrides_, class TmaGmemBasis_, class TmaSwizzle_>
|
||||
struct AuxTmaParams {
|
||||
using GmemStrides = GmemStrides_;
|
||||
using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic
|
||||
GmemStrides g_stride_;
|
||||
using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static
|
||||
static_assert(is_static<TmaGmemBasis>::value);
|
||||
using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle<B,M,S>
|
||||
static_assert(is_static<TmaSwizzle>::value);
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
@ -138,13 +142,19 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
|
||||
// Construct an executable SM90_TMA_LOAD with tma_mbar
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
|
||||
with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const {
|
||||
with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
|
||||
// We accept multicast_mask here to keep the API for both atoms consistent
|
||||
// assert(multicast_mask == 0);
|
||||
(void) multicast_mask;
|
||||
return {tma_desc_, tma_mbar};
|
||||
}
|
||||
|
||||
// Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
|
||||
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
|
||||
// We accept multicast_mask here to keep the API for both atoms consistent
|
||||
return {*new_tma_desc, tma_mbar};
|
||||
}
|
||||
|
||||
// Generate the TMA coord tensor
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -251,6 +261,13 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
|
||||
return {tma_desc_, tma_load_mbar, multicast_mask};
|
||||
}
|
||||
|
||||
// Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
|
||||
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
|
||||
return {*new_tma_desc, tma_load_mbar, multicast_mask};
|
||||
}
|
||||
|
||||
// Generate the TMA coord tensor
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -508,51 +525,91 @@ coalesce_256(Layout<Shape,Stride> const& layout)
|
||||
return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride));
|
||||
}
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
coalesce_256(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
return make_tensor(tensor.data(), coalesce_256(tensor.layout()));
|
||||
}
|
||||
|
||||
|
||||
// Use a smem_inv_h to read through the GMEM tensor
|
||||
// and construct a TMA Descriptor for the resulting instruction
|
||||
// At the same time, construct the Tma Tensor's Stride to generate
|
||||
// the TMA coordinates that the instruction consumes.
|
||||
//
|
||||
template <class TmaInternalType,
|
||||
class GEngine, class GLayout,
|
||||
class SShape, class SStride,
|
||||
int B, int M, int S>
|
||||
CUTE_HOST_RTC
|
||||
class VShape, class VStride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GMEM Tensor
|
||||
Layout<SShape,SStride> const& smem_inv_h, // smem_idx to hier gmode
|
||||
Swizzle<B,M,S> const& swizzle) // Swizzle fn on smem_idx
|
||||
construct_tma_gbasis(Tensor<GEngine,GLayout> const& gtensor, // The original GMEM Tensor
|
||||
Layout<SShape,SStride> const& slayout, // The layout of SMEM
|
||||
Layout<VShape,VStride> const& cta_v_map) // smem_idx to hier gmode
|
||||
{
|
||||
//
|
||||
// TMA parameter checking
|
||||
//
|
||||
|
||||
CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)),
|
||||
"TMA requires CTA_Tile and SLayout top-level shape equivalence.");
|
||||
|
||||
#if 0
|
||||
print("gtensor : "); print(gtensor); print("\n");
|
||||
print("slayout : "); print(slayout); print("\n");
|
||||
print("cta_v_map : "); print(cta_v_map); print("\n");
|
||||
#endif
|
||||
|
||||
//
|
||||
// TMA slayout manipulation
|
||||
//
|
||||
|
||||
// Invert the smem to get the largest contiguous vector in the smem layout
|
||||
// smem idx -> smem coord
|
||||
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
|
||||
|
||||
// Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode
|
||||
// smem idx -> gmem mode
|
||||
auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout));
|
||||
|
||||
#if 0
|
||||
print("inv_smem_layout : "); print(inv_smem_layout); print("\n");
|
||||
print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n");
|
||||
#endif
|
||||
|
||||
//
|
||||
// TMA gtensor truncation
|
||||
//
|
||||
|
||||
// Truncate any incompatibilities -- no starting in the middle of gmodes
|
||||
auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) {
|
||||
[[maybe_unused]] auto v = basis_value(e);
|
||||
return not is_constant<1,decltype(v)>{};
|
||||
});
|
||||
static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?");
|
||||
|
||||
// Keep only the static-1 basis modes into gmem
|
||||
auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full);
|
||||
|
||||
#if 0
|
||||
print("smem_rank : "); print(smem_rank); print("\n");
|
||||
print("sidx2gmode : "); print(sidx2gmode); print("\n");
|
||||
#endif
|
||||
|
||||
//
|
||||
// TMA gtensor manipulation
|
||||
//
|
||||
|
||||
// The smem vector is the same units as gtensor, so compose first and then recast
|
||||
// tma_val_idx:gmem_strides
|
||||
Tensor tile_gstride = recast<TmaInternalType>(gtensor.compose(smem_inv_h));
|
||||
auto tile_gstride = recast<TmaInternalType>(gtensor.compose(sidx2gmode)).layout();
|
||||
// Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType)
|
||||
// tma_box_shape:gmem_strides
|
||||
Tensor tma_gstride = coalesce_256(tile_gstride);
|
||||
auto tma_gstride = coalesce_256(tile_gstride);
|
||||
|
||||
// Perform the tiling to the gmem vector again, but with indirections to the gtensor modes
|
||||
// Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes
|
||||
auto gbasis = make_identity_layout(shape(gtensor));
|
||||
auto tile_gbasis_tmp = gbasis.compose(smem_inv_h);
|
||||
auto tile_gbasis_tmp = gbasis.compose(sidx2gmode);
|
||||
|
||||
// Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape
|
||||
// tma_box_shape:gmem_mode
|
||||
auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp));
|
||||
|
||||
// Recast the original tensor for shape inspections
|
||||
auto gtensor_T = recast<TmaInternalType>(gtensor);
|
||||
// "Coalesce" the tile basis into a compatible shape with the tma_gstride
|
||||
auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride))));
|
||||
|
||||
// Recast the original tensor for shape/stride inspections
|
||||
Tensor gtensor_T = recast<TmaInternalType>(gtensor);
|
||||
|
||||
// Find missing bases that don't appear in tile_gbasis
|
||||
// NOTE This is essentially ArithmeticTuple complement...
|
||||
// NOTE in pursuit of implementing an ArithmeticTuple logical_divide for smem_inv_h
|
||||
auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)),
|
||||
flatten(stride(gbasis)),
|
||||
[&](auto s, auto d, auto e)
|
||||
@ -561,7 +618,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
|
||||
return cute::tuple<>{}; // If size-1 or stride-0, then don't append
|
||||
} else {
|
||||
using E = decltype(e);
|
||||
auto has_e = any_of(stride(tile_gbasis), [] (auto tb) { return tb == E{}; });
|
||||
auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; });
|
||||
if constexpr (decltype(has_e)::value) {
|
||||
return cute::tuple<>{}; // If d was found, then don't append
|
||||
} else {
|
||||
@ -569,13 +626,10 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
|
||||
}
|
||||
}
|
||||
});
|
||||
auto tile_gbasis_remaining_rank = rank(tile_gbasis_remaining_stride);
|
||||
|
||||
// "Coalesce" the tile basis into a compatible shape with the tma
|
||||
auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride))));
|
||||
|
||||
// Append the remaining basis modes that contribute to the TMA with size-1
|
||||
auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(repeat<tile_gbasis_remaining_rank>(Int<1>{}))),
|
||||
auto tile_gbasis_remaining_shape = repeat<rank(tile_gbasis_remaining_stride)>(Int<1>{});
|
||||
auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )),
|
||||
tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride)));
|
||||
|
||||
// Group the trailing modes to make this max rank-5 -- TMA rank limitation
|
||||
@ -583,15 +637,98 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
|
||||
auto tma_gbasis = group<cute::min(rank(tma_gbasis_full),4),-1>(tma_gbasis_full);
|
||||
|
||||
#if 0
|
||||
print("smem_inv_h : "); print(smem_inv_h); print("\n");
|
||||
print("gtensor : "); print(gtensor); print("\n");
|
||||
print("tile_gstride : "); print(tile_gstride); print("\n");
|
||||
print("tma_gstride : "); print(tma_gstride); print("\n");
|
||||
print("gbasis : "); print(gbasis); print("\n");
|
||||
print("tile_gbasis : "); print(tile_gbasis); print("\n");
|
||||
print("tile_gbasis : "); print(tma_gbasis_tile); print("\n");
|
||||
print("tma_gbasis : "); print(tma_gbasis); print("\n");
|
||||
#endif
|
||||
|
||||
return tma_gbasis;
|
||||
}
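// Editor's aside: illustrative standalone sketch, not part of this header. The layout returned
// above carries ScaledBasis strides (E<0>, E<1>, ...) that remember which gtensor mode each TMA
// box mode reads, rather than integer strides. make_identity_layout shows the same bookkeeping:
#include <cute/tensor.hpp>
using namespace cute;

int main() {
  auto gbasis = make_identity_layout(make_shape(Int<64>{}, Int<128>{}));
  print(gbasis); print("\n");   // shape (_64,_128) with basis strides 1@0 and 1@1,
  return 0;                     // which downstream code reads back via basis_get
}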
template <class GEngine, class GLayout,
|
||||
class TmaGmemBasisStride,
|
||||
class ShapeT, size_t TmaRank>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType
|
||||
TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s)
|
||||
cute::array<ShapeT, TmaRank> & gmem_prob_shape, // Tma Shapes, uint32_t or uint64_t
|
||||
cute::array<uint64_t, TmaRank> & gmem_prob_stride) // Tma Strides
|
||||
{
|
||||
static_assert(is_tuple<TmaGmemBasisStride>::value);
|
||||
static_assert(is_same<uint32_t, ShapeT>::value || is_same<uint64_t, ShapeT>::value);
|
||||
|
||||
using TmaInternalType = typename GEngine::value_type;
|
||||
constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value;
|
||||
static_assert(TmaRank >= tma_rank);
|
||||
|
||||
auto gmem_shape = shape(gtensor);
|
||||
auto gmem_stride = stride(gtensor);
|
||||
// Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides
|
||||
for_each(make_seq<tma_rank>{}, [&](auto i) {
|
||||
constexpr int tma_i_rank = decltype(rank<i>(tma_gbasis_stride))::value;
|
||||
if constexpr (tma_i_rank == 1) {
|
||||
// Trivial contribution of this gmem mode to this tma mode
|
||||
auto ej = unwrap(get<i>(tma_gbasis_stride));
|
||||
gmem_prob_shape[i] = basis_get(ej, gmem_shape);
|
||||
gmem_prob_stride[i] = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
|
||||
} else {
|
||||
// Apply a recurrence to each gmem mode that contributes to this tma mode
|
||||
for_each(get<i>(tma_gbasis_stride), [&](auto ej) {
|
||||
// Problem shape
|
||||
uint64_t shape_j = basis_get(ej, gmem_shape);
|
||||
// Problem stride (in bytes)
|
||||
uint64_t stride_j = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
|
||||
|
||||
uint64_t old_stride = gmem_prob_stride[i];
|
||||
gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j);
|
||||
|
||||
if (gmem_prob_stride[i] != 0) {
|
||||
// Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1
|
||||
gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i])
|
||||
+ (shape_j-1) * (stride_j / gmem_prob_stride[i])
|
||||
+ 1;
|
||||
} else {
|
||||
gmem_prob_shape[i] = shape_j;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
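// Editor's aside: illustrative standalone sketch, not part of this header. The gcd recurrence
// above folds several gmem modes into one TMA mode: the folded stride is the gcd of the byte
// strides and the folded extent grows so the TMA box still reaches every address. The mode
// shapes/strides below (4 x 64B and 8 x 256B) are made-up values.
#include <cstdint>
#include <cstdio>
#include <numeric>   // std::gcd

int main() {
  uint64_t shape = 1, stride = 0;                 // accumulators, as in the loop above
  uint64_t modes[2][2] = {{4, 64}, {8, 256}};     // (s_j, d_j in bytes) pairs
  for (auto& m : modes) {
    uint64_t s_j = m[0], d_j = m[1];
    uint64_t old_stride = stride;
    stride = std::gcd(stride, d_j);
    if (stride != 0) {
      shape = (shape - 1) * (old_stride / stride)
            + (s_j   - 1) * (d_j        / stride) + 1;
    } else {
      shape = s_j;                                // all strides so far were zero
    }
  }
  std::printf("folded TMA mode: %llu elements of stride %llu bytes\n",
              (unsigned long long)shape, (unsigned long long)stride);   // 32 and 64
  return 0;
}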
// Overload for an existing Copy_Traits
|
||||
template <class GEngine, class GLayout,
|
||||
class Op, class Bits, class Aux,
|
||||
class ShapeT, size_t TmaRank>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
fill_tma_gmem_shape_stride(Copy_Traits<Op,Bits,Aux> const& tma_traits,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType
|
||||
cute::array<ShapeT, TmaRank> & gmem_prob_shape, // Tma Shapes, uint32_t or uint64_t
|
||||
cute::array<uint64_t, TmaRank> & gmem_prob_stride) // Tma Strides
|
||||
{
|
||||
return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}),
|
||||
gmem_prob_shape, gmem_prob_stride);
|
||||
}
|
||||
|
||||
// Use a sidx2gmode to read through the GMEM tensor
|
||||
// and construct a TMA Descriptor for the resulting instruction
|
||||
// At the same time, construct the Tma Tensor's Stride to generate
|
||||
// the TMA coordinates that the instruction consumes.
|
||||
//
|
||||
template <class TmaInternalType,
|
||||
class GEngine, class GLayout,
|
||||
class TShape, class TStride,
|
||||
int B, int M, int S>
|
||||
CUTE_HOST_RTC
|
||||
auto
|
||||
make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GMEM Tensor
|
||||
Layout<TShape,TStride> const& tma_gbasis, // TMA mode -> GMEM mode mapping
|
||||
Swizzle<B,M,S> const& swizzle, // Swizzle fn on smem_idx
|
||||
uint32_t num_multicast) // The number of CTAs in multicasting
|
||||
{
|
||||
//
|
||||
// TMA desc creation
|
||||
//
|
||||
@ -602,31 +739,16 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
|
||||
// TMA gmem desc info
|
||||
//
|
||||
|
||||
// Recast the original tensor for shape/stride inspections
|
||||
Tensor gtensor_T = recast<TmaInternalType>(gtensor);
|
||||
|
||||
void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data());
|
||||
auto gmem_layout = gtensor_T.layout();
|
||||
|
||||
cute::array<uint64_t, 5> gmem_prob_shape = {1,1,1,1,1};
|
||||
cute::array<uint64_t, 5> gmem_prob_stride = {0,0,0,0,0};
|
||||
// Use the indirections in tma_gbasis in the values of flat_glayout to construct the gmem shapes/strides
|
||||
for_each(make_seq<tma_dim>{}, [&](auto i) {
|
||||
for_each(stride<i>(tma_gbasis), [&](auto ej) {
|
||||
// Problem stride
|
||||
uint64_t stride_j = ceil_div(basis_get(ej, stride(gmem_layout)) * sizeof_bits_v<TmaInternalType>, 8);
|
||||
uint64_t old_stride = gmem_prob_stride[i];
|
||||
gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j);
|
||||
|
||||
// Problem shape
|
||||
uint64_t shape_j = basis_get(ej, shape(gmem_layout));
|
||||
if (gmem_prob_stride[i] != 0) {
|
||||
// Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1
|
||||
gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i])
|
||||
+ (shape_j-1) * (stride_j / gmem_prob_stride[i])
|
||||
+ 1;
|
||||
} else {
|
||||
gmem_prob_shape[i] = shape_j;
|
||||
}
|
||||
});
|
||||
});
|
||||
fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride);
|
||||
|
||||
assert((reinterpret_cast<uint64_t>(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned
|
||||
|
||||
@ -663,6 +785,13 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
|
||||
for_each(make_seq<tma_dim>{}, [&](auto i) {
|
||||
smem_box_shape[i] *= size<i>(tma_gbasis);
|
||||
});
|
||||
// Finally, truncate the tma box by the num_multicast
|
||||
for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) {
|
||||
assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0);
|
||||
uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]);
|
||||
smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast);
|
||||
multicast = new_mult;
|
||||
}
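// Editor's aside (illustrative, values made up): with smem_box_shape = {128, 8, 1, 1, 1} and
// num_multicast = 2, the loop above halves the last participating mode, so each of the two CTAs
// in the cluster loads a {128, 4, 1, 1, 1} box; a multicast factor larger than a mode divides
// that mode down to 1 and carries the remaining factor into the next-lower mode.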
assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1
|
||||
assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256
|
||||
@ -740,26 +869,27 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
|
||||
auto recast_ratio = cute::ratio(Int<sizeof_bits<typename GEngine::value_type>::value>{},
|
||||
Int<sizeof_bits< TmaInternalType>::value>{});
|
||||
|
||||
auto gbasis = make_basis_like(shape(gtensor));
|
||||
|
||||
// Finally, get the inverse permutation of the E<i> bases for the mocked gmem stride
|
||||
// NOTE This is essentially ArithmeticTuple inverse...
|
||||
auto gmem_stride_bases = transform_leaf(stride(gbasis), [&](auto ei) {
|
||||
auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) {
|
||||
auto si = basis_get(ei, shape(gmem_layout));
|
||||
auto di = basis_get(ei, stride(gmem_layout));
|
||||
if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) {
|
||||
return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA
|
||||
} else {
|
||||
auto tma_gbasis_stride = stride(tma_gbasis);
|
||||
auto tma_gmem_basis_stride = stride(tma_gbasis);
|
||||
// Find j such that E<i> is in stride<j>(tma_gbasis)
|
||||
using EI = decltype(ei);
|
||||
[[maybe_unused]] auto j = find_if(tma_gbasis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); });
|
||||
if constexpr (decltype(j == rank(tma_gbasis_stride))::value) {
|
||||
[[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); });
|
||||
if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) {
|
||||
return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA
|
||||
} else
|
||||
if constexpr (decltype(j == Int<0>{})::value) {
|
||||
auto scale = recast_ratio * basis_get(ei, stride(gtensor));
|
||||
return E<j>{} * scale; // Return TMA Coord basis -- with a recast scale factor
|
||||
} else
|
||||
if constexpr (decltype(rank<j>(tma_gbasis_stride) == Int<1>{})::value) {
|
||||
if constexpr (decltype(rank<j>(tma_gmem_basis_stride) == Int<1>{})::value) {
|
||||
return E<j>{}; // Return TMA Coord basis -- known scale of Int<1>{}
|
||||
} else {
|
||||
int32_t scale = ceil_div(int32_t(di * sizeof_bits_v<TmaInternalType> / cute::max(gmem_prob_stride[j], 16)), 8);
|
||||
@ -768,14 +898,64 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The original GM
|
||||
}
|
||||
});
|
||||
|
||||
#if 0
|
||||
print("tma_gbasis : "); print(gmem_stride_bases); print("\n");
|
||||
#endif
|
||||
#if 0
|
||||
print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n");
|
||||
#endif
|
||||
|
||||
using AuxParams = AuxTmaParams<decltype(gmem_stride_bases),
|
||||
using AuxParams = AuxTmaParams<decltype(gmem_tma_basis_stride),
|
||||
decltype(tma_gbasis),
|
||||
decltype(swizzle)>;
|
||||
return cute::make_tuple(tma_desc, AuxParams{gmem_stride_bases});
|
||||
return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride});
|
||||
}
|
||||
|
||||
template <class TmaInternalType,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class VShape, class VStride>
|
||||
CUTE_HOST_RTC
|
||||
auto
|
||||
make_tma_copy_atom(CopyOp,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // Full GMEM Tensor
|
||||
SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled
|
||||
uint32_t const& num_multicast, // The number of CTAs involved in multicasting
|
||||
Layout<VShape,VStride> const& cta_v_map) // V: CTA val idx -> gmem mode
|
||||
{
|
||||
//
|
||||
// TMA truncated layout
|
||||
//
|
||||
|
||||
auto smem_swizzle = get_swizzle_portion(slayout);
|
||||
auto smem_layout = get_nonswizzle_portion(slayout);
|
||||
|
||||
auto tma_gbasis = detail::construct_tma_gbasis<TmaInternalType>(gtensor, smem_layout, cta_v_map);
|
||||
|
||||
//
|
||||
// Construct the TMA Desc and the strides of the TMA Tensor
|
||||
//
|
||||
|
||||
auto [tma_desc, aux_params] = detail::make_tma_copy_desc<TmaInternalType>(gtensor,
|
||||
tma_gbasis,
|
||||
smem_swizzle,
|
||||
num_multicast);
|
||||
|
||||
//
|
||||
// Construct the Copy_Traits
|
||||
//
|
||||
|
||||
constexpr int num_bits_per_tma = decltype(size(tma_gbasis))::value * sizeof_bits_v<TmaInternalType>;
|
||||
using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(aux_params)>;
|
||||
using Atom = Copy_Atom<Traits, typename GEngine::value_type>;
|
||||
|
||||
Traits tma_traits{tma_desc, aux_params};
|
||||
|
||||
#if 0
|
||||
print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n");
|
||||
print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n");
|
||||
#endif
|
||||
|
||||
// Return the Copy_Atom
|
||||
return Atom{tma_traits};
|
||||
}
|
||||
|
||||
// The "logical TMA tid" is a map from the CTA rank to its logical id
|
||||
@ -790,122 +970,46 @@ template <class TmaInternalType,
|
||||
class VShape, class VStride>
|
||||
CUTE_HOST_RTC
|
||||
auto
|
||||
make_tma_copy_tiled(CopyOp,
|
||||
make_tma_copy_tiled(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // Full GMEM Tensor
|
||||
SLayout const& slayout, // CTA Tile of SMEM
|
||||
Layout<TShape,TStride> const& cta_t_map, // T: CTA thr idx -> logical TMA tid
|
||||
Layout<VShape,VStride> const& cta_v_map) // V: CTA val idx -> gmem mode
|
||||
{
|
||||
//
|
||||
// TMA parameter checking
|
||||
//
|
||||
|
||||
CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)),
|
||||
"TMA requires CTA_Tile and SLayout top-level shape equivalence.");
|
||||
CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{},
|
||||
"Number of active CTAs in TMA must divide domain size of slayout.");
|
||||
|
||||
//
|
||||
// TMA slayout manipulation
|
||||
//
|
||||
|
||||
// Invert the smem to get the largest contiguous vector in the smem layout
|
||||
// smem idx -> smem coord
|
||||
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
|
||||
|
||||
// Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode
|
||||
// smem idx -> gmem mode
|
||||
auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout));
|
||||
|
||||
#if 0
|
||||
print("g_tensor : "); print(gtensor); print("\n");
|
||||
print("s_layout : "); print(slayout); print("\n");
|
||||
print("cta_t_map : "); print(cta_t_map); print("\n");
|
||||
print("cta_v_map : "); print(cta_v_map); print("\n");
|
||||
print("inv_s_layout : "); print(inv_smem_layout); print("\n");
|
||||
print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n");
|
||||
#endif
|
||||
|
||||
//
|
||||
// TMA gtensor manipulation
|
||||
//
|
||||
|
||||
// Generate a TupleBasis for the gtensor
|
||||
// gmem coord -> gmem coord
|
||||
auto glayout_basis = make_identity_layout(shape(gtensor));
|
||||
|
||||
// Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc
|
||||
// smem idx -> gmem coord
|
||||
auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode));
|
||||
|
||||
// Truncate any incompatibilities -- no starting in the middle of gmodes
|
||||
auto smem_rank = find_if(stride(tma_layout_full), [](auto e) {
|
||||
[[maybe_unused]] auto v = basis_value(e);
|
||||
return not is_constant<1,decltype(v)>{};
|
||||
});
|
||||
static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?");
|
||||
|
||||
// Keep only the static-1 basis modes into gmem
|
||||
auto tma_layout_trunc = take<0,smem_rank>(tma_layout_full);
|
||||
|
||||
// Keep only the portion each multicast CTA will be responsible for
|
||||
auto tma_layout_v = composition(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map)));
|
||||
|
||||
#if 0
|
||||
print("glayout_basis : "); print(glayout_basis); print("\n");
|
||||
print("tma_layout_full : "); print(tma_layout_full); print("\n");
|
||||
|
||||
print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n");
|
||||
print("tma_layout_v : "); print(tma_layout_v); print("\n");
|
||||
#endif
|
||||
|
||||
//
|
||||
// Construct the TMA Desc and the strides of the TMA Tensor
|
||||
//
|
||||
|
||||
auto [tma_desc, aux_params] = detail::make_tma_copy_desc<TmaInternalType>(gtensor,
|
||||
tma_layout_v,
|
||||
get_swizzle_portion(slayout));
|
||||
|
||||
//
|
||||
// Construct the Copy_Traits
|
||||
//
|
||||
|
||||
using T = typename GEngine::value_type;
|
||||
constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof_bits_v<T>;
|
||||
using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(aux_params)>;
|
||||
using Atom = Copy_Atom<Traits, T>;
|
||||
|
||||
Traits tma_traits{tma_desc, aux_params};
|
||||
|
||||
#if 0
|
||||
print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n");
|
||||
print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n");
|
||||
#endif
|
||||
Copy_Atom atom = make_tma_copy_atom<TmaInternalType>(copy_op, gtensor, slayout,
|
||||
cosize(cta_t_map), cta_v_map);
|
||||
|
||||
//
|
||||
// Construct the TiledCopy
|
||||
//
|
||||
|
||||
auto cta_tiler = product_each(shape(cta_v_map));
|
||||
[[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map));
|
||||
|
||||
auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / Int<sizeof_bits_v<typename GEngine::value_type>>{};
|
||||
|
||||
// smem idx -> smem coord
|
||||
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
|
||||
// CTA V -> smem_coord
|
||||
auto layout_v = composition(inv_smem_layout, size(tma_layout_trunc));
|
||||
auto layout_v = composition(inv_smem_layout, num_elems_per_tma);
|
||||
// Scale that up to cover all of the smem_coords
|
||||
auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map));
|
||||
// CTA T -> smem idx
|
||||
auto layout_t = make_layout(cosize(cta_t_map), shape_div(size(tma_layout_trunc), cosize(cta_t_map)));
|
||||
auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map)));
|
||||
// CTA TID -> smem coord
|
||||
auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map));
|
||||
// Combine with the T mapping
|
||||
auto layout_TV = make_layout(layout_T, layout_V);
|
||||
[[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V);
|
||||
|
||||
#if 0
|
||||
print("cta_tiler : "); print(cta_tiler); print("\n");
|
||||
print("layout_VT : "); print(layout_VT); print("\n");
|
||||
print("layout_v : "); print(layout_v); print("\n");
|
||||
print("layout_V : "); print(layout_V); print("\n");
|
||||
print("layout_t : "); print(layout_t); print("\n");
|
||||
print("layout_T : "); print(layout_T); print("\n");
|
||||
print("layout_TV : "); print(layout_TV); print("\n");
|
||||
#endif
|
||||
|
||||
return TiledCopy<Atom, decltype(layout_TV), decltype(cta_tiler)>{tma_traits};
|
||||
return TiledCopy<decltype(atom), decltype(layout_TV), decltype(cta_tiler)>{atom};
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
@ -982,7 +1086,7 @@ make_tma_copy_tiled(CopyOp,
|
||||
|
||||
copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params
|
||||
*/
|
||||
template <class TmaInternalType,
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
@ -996,37 +1100,16 @@ make_tma_copy(CopyOp const& copy_op,
|
||||
CTA_Tiler const& cta_tiler,
|
||||
Cluster_Size const& cluster_size)
|
||||
{
|
||||
auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler);
|
||||
auto cta_t_tile = make_layout(cluster_size);
|
||||
return detail::make_tma_copy_tiled<TmaInternalType>(copy_op,
|
||||
gtensor,
|
||||
slayout,
|
||||
cta_t_tile,
|
||||
cta_v_tile);
|
||||
auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler);
|
||||
auto cta_t_tile = make_layout(cluster_size);
|
||||
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
|
||||
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
|
||||
return detail::make_tma_copy_tiled<TmaType>(copy_op,
|
||||
gtensor, slayout,
|
||||
cta_t_tile, cta_v_tile);
|
||||
}
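// Editor's aside: illustrative sketch, not part of this diff. Typical host-side usage on an SM90
// part with CUDA 12+, with placeholder names (d_A, M, K); the smem layout and CTA tile must agree
// in top-level shape, as asserted in make_tma_copy_tiled:
//
//   Tensor gA = make_tensor(make_gmem_ptr(d_A),                            // device pointer
//                           make_shape(M, K), make_stride(K, Int<1>{}));   // row-major GMEM
//   auto smem_layout = make_layout(make_shape(Int<128>{}, Int<64>{}),
//                                  make_stride(Int<64>{}, Int<1>{}));      // CTA tile in SMEM
//   auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout,
//                                 make_shape(Int<128>{}, Int<64>{}),       // CTA tiler
//                                 Int<1>{});                               // no multicast
//   // Omitting TmaInternalType means the element type of gA is used, per the defaulting above.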
// Explicit defaulting
|
||||
template <class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class CTA_Tile,
|
||||
class Cluster_Size>
|
||||
CUTE_HOST_RTC
|
||||
auto
|
||||
make_tma_copy(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor,
|
||||
SLayout const& slayout,
|
||||
CTA_Tile const& cta_tile,
|
||||
Cluster_Size const& cluster_size)
|
||||
{
|
||||
using TmaInternalType = typename GEngine::value_type;
|
||||
return make_tma_copy<TmaInternalType>(copy_op,
|
||||
gtensor,
|
||||
slayout,
|
||||
cta_tile,
|
||||
cluster_size);
|
||||
}
|
||||
|
||||
template <class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout>
|
||||
@ -1039,6 +1122,7 @@ make_tma_copy(CopyOp const& copy_op,
|
||||
return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{});
|
||||
}
|
||||
|
||||
// Explicit defaulting
|
||||
template <class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
@ -1053,4 +1137,86 @@ make_tma_copy(CopyOp const& copy_op,
|
||||
return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size);
|
||||
}
|
||||
|
||||
////////////////////////////////////
|
||||
// Experimental Make TMA Atom and Partitioner
|
||||
///////////////////////////////////
|
||||
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class CTA_Tiler,
|
||||
class Cluster_Size>
|
||||
CUTE_HOST_RTC
|
||||
auto
|
||||
make_tma_atom(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor,
|
||||
SLayout const& slayout,
|
||||
CTA_Tiler const& cta_tiler,
|
||||
Cluster_Size const& cluster_size)
|
||||
{
|
||||
auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler);
|
||||
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
|
||||
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
|
||||
return detail::make_tma_copy_atom<TmaType>(copy_op,
|
||||
gtensor, slayout,
|
||||
size(cluster_size), cta_v_tile);
|
||||
}
|
||||
|
||||
// The "VectorCopy Partitioner" for TMA
|
||||
template <class... Args,
|
||||
class CtaCoord,
|
||||
class TShape, class TStride,
|
||||
class SEngine, class SLayout,
|
||||
class GEngine, class GLayout>
|
||||
CUTE_DEVICE
|
||||
auto
|
||||
tma_partition(Copy_Atom<Args...> const& copy_atom,
|
||||
CtaCoord const& cta_coord,
|
||||
Layout<TShape,TStride> const& cta_layout, // T: CTA coord -> logical multicast id
|
||||
Tensor<SEngine,SLayout> const& stensor, // SMEM Tensor (TMATile, Iter)
|
||||
Tensor<GEngine,GLayout> const& gtensor) // GMEM Tensor (TMATile, Iter)
|
||||
{
|
||||
// Invert the smem to get the largest contiguous vector in the smem layout
|
||||
Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor)));
|
||||
// Scale that up to cover all of the smem_coords
|
||||
Layout layout_v = tile_to_shape(make_layout(inv_smem_layout), size<0>(stensor));
|
||||
|
||||
// Factor out the single-instruction portion
|
||||
Layout tma_layout_v = make_layout(Int<Copy_Atom<Args...>::NumValSrc>{});
|
||||
Layout layout_V = logical_divide(layout_v, tma_layout_v);
|
||||
|
||||
// Transform tile mode and coalesce
|
||||
Tensor gtensor_v = coalesce(gtensor.compose(layout_V, _), Shape<Shape<_1,_1>,_1>{}); // ((TMA,TMA_Iter),Iter)
|
||||
Tensor stensor_v = coalesce(stensor.compose(layout_V, _), Shape<Shape<_1,_1>,_1>{}); // ((TMA,TMA_Iter),Iter)
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("layout_V : "); print(layout_V); print("\n");
|
||||
print("gtensor_v : "); print(gtensor_v); print("\n");
|
||||
print("stensor_v : "); print(stensor_v); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
// Restride the cta-into-tma-instr layout
|
||||
Layout tma_layout_t = composition(make_layout(Int<1>{}, shape_div(size(tma_layout_v), cosize(cta_layout))), cta_layout);
|
||||
Layout tma_layout_tv = make_layout(tma_layout_t, tma_layout_v);
|
||||
|
||||
// Transform TMA mode
|
||||
Tensor gtensor_tv = gtensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter)
|
||||
Tensor stensor_tv = stensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter)
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("tma_layout_tv : "); print(tma_layout_tv); print("\n");
|
||||
print("gtensor_tv : "); print(gtensor_tv); print("\n");
|
||||
print("stensor_tv : "); print(stensor_tv); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
// Slice and group Frg,TMA_Iter and return
|
||||
auto c = make_coord(make_coord(make_coord(cta_coord, _), _), _);
|
||||
return cute::make_tuple(group_modes<0,2>(gtensor_tv(c)), group_modes<0,2>(stensor_tv(c)));
|
||||
}
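// Editor's aside: illustrative sketch, not part of this diff. Inside a kernel the partitioner is
// typically driven by the CTA's rank in the cluster; the names below (tma_atom, sA, gA, k_tile,
// stage, smem_mbar, mcast_mask) are placeholders:
//
//   auto cta_layout = Layout<Shape<_2>>{};                   // 2-CTA multicast group
//   auto [tAgA, tAsA] = tma_partition(tma_atom, cta_rank_in_cluster, cta_layout,
//                                     group_modes<0,3>(sA), group_modes<0,3>(gA));
//   copy(tma_atom.with(smem_mbar, mcast_mask), tAgA(_,k_tile), tAsA(_,stage));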
} // end namespace cute
|
||||
|
||||
@ -31,12 +31,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
#include <cute/atom/mma_traits.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
namespace cute {
|
||||
@ -196,39 +196,37 @@ struct MMA_Atom<MMA_Traits<Args...>>
|
||||
template <class TiledMMA, class ThrCoord>
|
||||
struct ThrMMA;
|
||||
|
||||
// @tparam MMA_Atom The MMA_Atom to use in the TiledMMA
|
||||
// @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed.
|
||||
// @tparam PermutationMNK Permutations to apply to each MNK-mode before tiling for the Atom.
|
||||
template <class MMA_Atom,
|
||||
class AtomLayoutMNK = Layout<Shape<_1,_1,_1>>,
|
||||
class ValLayoutMNK = Layout<Shape<_1,_1,_1>>,
|
||||
class PermutationsMNK = Tile<Underscore,Underscore,Underscore>>
|
||||
class AtomLayoutMNK,
|
||||
class PermutationMNK = Tile<Underscore,Underscore,Underscore>>
|
||||
struct TiledMMA : MMA_Atom
|
||||
{
|
||||
static_assert(rank_v<AtomLayoutMNK> == 3, "TiledMMA requires rank-3 AtomLayoutMNK");
|
||||
static_assert(rank_v<ValLayoutMNK> == 3, "TiledMMA requires rank-3 ValLayoutMNK");
|
||||
static_assert(rank_v<PermutationsMNK> == 3, "TiledMMA requires rank-3 PermutationsMNK");
|
||||
|
||||
using Atom = MMA_Atom;
|
||||
using AtomShape_MNK = typename MMA_Atom::Shape_MNK;
|
||||
|
||||
using AtomThrID = typename MMA_Atom::ThrID;
|
||||
using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV;
|
||||
using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV;
|
||||
using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV;
|
||||
|
||||
// ThrV -> thread_idx
|
||||
using AtomThrID = typename MMA_Atom::ThrID;
|
||||
static_assert( rank_v<AtomLayoutMNK> == 3, "TiledMMA requires rank-3 AtomLayoutMNK");
|
||||
static_assert( rank_v<PermutationMNK> == 3, "TiledMMA requires rank-3 PermutationMNK");
|
||||
static_assert( is_tile<PermutationMNK>::value, "TiledMMA requires independent permutations of MNK.");
|
||||
static_assert(is_static<PermutationMNK>::value, "TiledMMA requires static permutations of MNK.");
|
||||
|
||||
// (M,N,K)
|
||||
using TiledShape_MNK = decltype(make_shape(size<0>(AtomShape_MNK{})*size<0>(AtomLayoutMNK{})*size<0>(ValLayoutMNK{}),
|
||||
size<1>(AtomShape_MNK{})*size<1>(AtomLayoutMNK{})*size<1>(ValLayoutMNK{}),
|
||||
size<2>(AtomShape_MNK{})*size<2>(AtomLayoutMNK{})*size<2>(ValLayoutMNK{})));
|
||||
|
||||
// thrid = (ThrV,ThrM,ThrN,ThrK) -> thr_idx
|
||||
using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{}));
|
||||
ThrLayoutVMNK thr_layout_vmnk_;
|
||||
|
||||
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
|
||||
using TidLayout = decltype(right_inverse(ThrLayoutVMNK{}));
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {})
|
||||
: MMA_Atom(mma_atom),
|
||||
thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
get_thr_layout_vmnk() const {
|
||||
return ThrLayoutVMNK{};
|
||||
return thr_layout_vmnk_;
|
||||
}
|
||||
|
||||
// Tile a tensor or a layout from shape
|
||||
@ -243,17 +241,17 @@ struct TiledMMA : MMA_Atom
|
||||
// RestM: The values tiled in M.
|
||||
// RestN: The values tiled in N.
|
||||
template <class CTensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
thrfrg_C(CTensor&& ctensor)
|
||||
thrfrg_C(CTensor&& ctensor) const
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
|
||||
|
||||
// Reorder the tensor for the TiledAtom
|
||||
auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})),
|
||||
left_inverse(get<1>(PermutationsMNK{})));
|
||||
auto t_tile = make_tile(get<0>(PermutationMNK{}),
|
||||
get<1>(PermutationMNK{}));
|
||||
auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN)
|
||||
|
||||
// Tile the tensor for the Atom
|
||||
@ -266,25 +264,13 @@ struct TiledMMA : MMA_Atom
|
||||
|
||||
// Tile the tensor for the C-threads
|
||||
auto thr_tile = make_tile(_,
|
||||
make_tile(make_layout(size<1>(ThrLayoutVMNK{})),
|
||||
make_layout(size<2>(ThrLayoutVMNK{}))));
|
||||
make_tile(make_layout(size<1>(thr_layout_vmnk_)),
|
||||
make_layout(size<2>(thr_layout_vmnk_))));
|
||||
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN)))
|
||||
|
||||
return thr_tensor;
|
||||
}
|
||||
|
||||
// Tile from (M,N,...)
|
||||
// to (thr_idx,(FrgV,(RestM,RestN,...)))
|
||||
template <class CTensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
tidfrg_C(CTensor&& ctensor)
|
||||
{
|
||||
// Don't need a ctile composition because ThrK is last mode in TidLayout
|
||||
|
||||
return thrfrg_C(ctensor).compose(TidLayout{}, _);
|
||||
}
|
||||
|
||||
// Tile a tensor or a layout from shape
|
||||
// (M,K,...)
|
||||
// to shape
|
||||
@ -297,17 +283,17 @@ struct TiledMMA : MMA_Atom
|
||||
// RestM: The values tiled in M.
|
||||
// RestK: The values tiled in K.
|
||||
template <class ATensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
thrfrg_A(ATensor&& atensor)
|
||||
thrfrg_A(ATensor&& atensor) const
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
|
||||
//UTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{});
|
||||
|
||||
// Reorder the tensor for the TiledAtom
|
||||
auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})),
|
||||
left_inverse(get<2>(PermutationsMNK{})));
|
||||
auto t_tile = make_tile(get<0>(PermutationMNK{}),
|
||||
get<2>(PermutationMNK{}));
|
||||
auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK)
|
||||
|
||||
// Tile the tensor for the Atom
|
||||
@ -320,29 +306,13 @@ struct TiledMMA : MMA_Atom
|
||||
|
||||
// Tile the tensor for the Thread
|
||||
auto thr_tile = make_tile(_,
|
||||
make_tile(make_layout(size<1>(ThrLayoutVMNK{})),
|
||||
make_layout(size<3>(ThrLayoutVMNK{}))));
|
||||
make_tile(make_layout(size<1>(thr_layout_vmnk_)),
|
||||
make_layout(size<3>(thr_layout_vmnk_))));
|
||||
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK)))
|
||||
|
||||
return thr_tensor;
|
||||
}
|
||||
|
||||
// Tile from (M,K,...)
|
||||
// to (thr_idx,(FrgV,(RestM,RestK,...)))
|
||||
template <class ATensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
tidfrg_A(ATensor&& atensor)
|
||||
{
|
||||
auto atile = make_tile(_,
|
||||
make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})),
|
||||
make_stride( Int<1>{} , Int<0>{} )),
|
||||
_));
|
||||
// (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
|
||||
|
||||
return thrfrg_A(atensor).compose(atile, _).compose(TidLayout{}, _);
|
||||
}
|
||||
|
||||
// Tile a tensor or a layout from shape
|
||||
// (N,K,...)
|
||||
// to shape
|
||||
@ -355,17 +325,17 @@ struct TiledMMA : MMA_Atom
|
||||
// RestN: The values tiled in N.
|
||||
// RestK: The values tiled in K.
|
||||
template <class BTensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
thrfrg_B(BTensor&& btensor)
|
||||
thrfrg_B(BTensor&& btensor) const
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{});
|
||||
|
||||
// Reorder the tensor for the TiledAtom
|
||||
auto t_tile = make_tile(left_inverse(get<1>(PermutationsMNK{})),
|
||||
left_inverse(get<2>(PermutationsMNK{})));
|
||||
auto t_tile = make_tile(get<1>(PermutationMNK{}),
|
||||
get<2>(PermutationMNK{}));
|
||||
auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK)
|
||||
|
||||
// Tile the tensor for the Atom
|
||||
@ -378,44 +348,28 @@ struct TiledMMA : MMA_Atom
|
||||
|
||||
// Tile the tensor for the Thread
|
||||
auto thr_tile = make_tile(_,
|
||||
make_tile(make_layout(size<2>(ThrLayoutVMNK{})),
|
||||
make_layout(size<3>(ThrLayoutVMNK{}))));
|
||||
make_tile(make_layout(size<2>(thr_layout_vmnk_)),
|
||||
make_layout(size<3>(thr_layout_vmnk_))));
|
||||
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK)))
|
||||
|
||||
return thr_tensor;
|
||||
}
|
||||
|
||||
// Tile from (N,K,...)
|
||||
// to (thr_idx,(FrgV,(RestN,RestK,...)))
|
||||
template <class BTensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
template <class ThrIdx,
|
||||
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tidfrg_B(BTensor&& btensor)
|
||||
get_slice(ThrIdx const& thr_idx) const
|
||||
{
|
||||
auto btile = make_tile(_,
|
||||
make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})),
|
||||
make_stride( Int<0>{} , Int<1>{} )),
|
||||
_));
|
||||
// (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
|
||||
|
||||
return thrfrg_B(btensor).compose(btile, _).compose(TidLayout{}, _);
|
||||
auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx);
|
||||
return ThrMMA<TiledMMA, decltype(thr_vmnk)>{*this, thr_vmnk};
|
||||
}
|
||||
|
||||
template <class ThrIdx,
|
||||
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_slice(ThrIdx const& thr_idx)
|
||||
{
|
||||
auto thr_vmnk = ThrLayoutVMNK{}.get_flat_coord(thr_idx);
|
||||
return ThrMMA<TiledMMA, decltype(thr_vmnk)>(thr_vmnk);
|
||||
}
|
||||
|
||||
template <class ThrIdx,
|
||||
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
auto
|
||||
get_thread_slice(ThrIdx const& thr_idx)
|
||||
get_thread_slice(ThrIdx const& thr_idx) const
|
||||
{
|
||||
return get_slice(thr_idx);
|
||||
}
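// Editor's aside: illustrative sketch, not part of this header. Because get_slice is now a
// non-static member reading thr_layout_vmnk_, the usual partitioning pattern goes through a
// TiledMMA object; gC, gA, gB below are placeholder (M,N)/(M,K)/(N,K) gmem tiles:
//
//   auto thr_mma = tiled_mma.get_slice(threadIdx.x);
//   Tensor tCgC = thr_mma.partition_C(gC);   // (MMA, MMA_M, MMA_N)
//   Tensor tCgA = thr_mma.partition_A(gA);   // (MMA, MMA_M, MMA_K)
//   Tensor tCgB = thr_mma.partition_B(gB);   // (MMA, MMA_N, MMA_K)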
@ -424,104 +378,144 @@ struct TiledMMA : MMA_Atom
|
||||
// Utility for printing and visualization
|
||||
//
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
// The size of the MNK-mode
|
||||
template <int I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutC_MN()
|
||||
tile_size_mnk() const {
|
||||
static_assert(0 <= I && I < 3);
|
||||
auto core_size = size<I>(AtomShape_MNK{}) * size<I+1>(get_thr_layout_vmnk());
|
||||
[[maybe_unused]] auto perm_size = size<I>(PermutationMNK{});
|
||||
if constexpr (is_underscore<decltype(perm_size)>::value) {
|
||||
return core_size;
|
||||
} else {
|
||||
return cute::max(core_size, perm_size);
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutC_MN() const
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{})));
|
||||
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
|
||||
// (cthrid,val) -> (M,N)
|
||||
auto layoutC_TV = thrfrg_C(ref_C);
|
||||
// (M,N) -> (cthrid,frg)
|
||||
auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C));
|
||||
|
||||
// cthrid = (v,m,n) -> thr_idx
|
||||
auto thrID_C = ThrLayoutVMNK{}(_,_,_,Int<0>{});
|
||||
auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{});
|
||||
|
||||
return cute::make_tuple(layoutC_MN, thrID_C);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutC_TV()
|
||||
get_layoutC_TV() const
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{})));
|
||||
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
|
||||
// (cthrid,val) -> (M,N)
|
||||
auto layoutC_TV = thrfrg_C(ref_C);
|
||||
|
||||
return tidfrg_C(ref_C);
|
||||
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
|
||||
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
|
||||
|
||||
// (thr_idx,val) -> (M,N)
|
||||
return layoutC_TV.compose(thridx_2_thrid, _);
|
||||
}
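// Editor's aside: illustrative standalone sketch, not part of this header. The compose with
// right_inverse(thr_layout_vmnk_) above re-indexes the (ThrV,ThrM,ThrN,ThrK) thread modes by the
// flat thread index; for a compact thread layout the right inverse is simply the identity map:
#include <cute/tensor.hpp>
using namespace cute;

int main() {
  auto thr_vmnk = make_layout(make_shape(Int<32>{}, Int<2>{}, Int<2>{}, Int<1>{}));  // (V,M,N,K) -> thr_idx
  print(right_inverse(thr_vmnk)); print("\n");    // thr_idx -> (V,M,N,K); here a size-128 identity
  return 0;
}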
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutA_MK()
|
||||
get_layoutA_MK() const
|
||||
{
|
||||
// (M,K) -> (M,K)
|
||||
auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
|
||||
auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>()));
|
||||
// (athrid,val) -> (M,K)
|
||||
auto layoutA_TV = thrfrg_A(ref_A);
|
||||
// (M,K) -> (athrid,frg)
|
||||
auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_A = ThrLayoutVMNK{}(_,_,Int<0>{},_);
|
||||
auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_);
|
||||
|
||||
return cute::make_tuple(layoutA_MK, thrID_A);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutA_TV()
|
||||
get_layoutA_TV() const
|
||||
{
|
||||
// (M,K) -> (M,K)
|
||||
auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
|
||||
auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>()));
|
||||
// (athrid,val) -> (M,K)
|
||||
auto layoutA_TV = thrfrg_A(ref_A);
|
||||
|
||||
return tidfrg_A(ref_A);
|
||||
// (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
|
||||
auto atile = make_tile(_,
|
||||
make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)),
|
||||
make_stride( Int<1>{} , Int<0>{} )),
|
||||
_));
|
||||
|
||||
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
|
||||
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
|
||||
|
||||
// (thr_idx,val) -> (M,K)
|
||||
return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutB_NK()
|
||||
get_layoutB_NK() const
|
||||
{
|
||||
// (N,K) -> (N,K)
|
||||
auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
|
||||
auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>()));
|
||||
// (bthrid,val) -> (N,K)
|
||||
auto layoutB_TV = thrfrg_B(ref_B);
|
||||
// (N,K) -> (bthrid,frg)
|
||||
auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B));
|
||||
|
||||
// bthrid = (v,n,k) -> thr_idx
|
||||
auto thrID_B = ThrLayoutVMNK{}(_,Int<0>{},_,_);
|
||||
auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_);
|
||||
|
||||
return cute::make_tuple(layoutB_NK, thrID_B);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutB_TV()
|
||||
get_layoutB_TV() const
|
||||
{
|
||||
// (N,K) -> (N,K)
|
||||
auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{})));
|
||||
auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>()));
|
||||
// (bthrid,val) -> (N,K)
|
||||
auto layoutB_TV = thrfrg_B(ref_B);
|
||||
|
||||
return tidfrg_B(ref_B);
|
||||
// (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
|
||||
auto btile = make_tile(_,
|
||||
make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)),
|
||||
make_stride( Int<0>{} , Int<1>{} )),
|
||||
_));
|
||||
|
||||
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
|
||||
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
|
||||
|
||||
// (thr_idx,val) -> (N,K)
|
||||
return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _);
|
||||
}
|
||||
};
|
||||
|
||||
template <class TiledMMA, class ThrVMNK>
|
||||
struct ThrMMA : TiledMMA
|
||||
{
|
||||
// Use ThrVMNK and thrfrg rather than thr_idx and tidfrg
|
||||
// to support swizzled threads partitioning dynamic layouts
|
||||
ThrVMNK thr_vmnk_;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ThrMMA(ThrVMNK const& thr_vmnk) : thr_vmnk_(thr_vmnk) {}
|
||||
|
||||
template <class CTensor>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
partition_C(CTensor&& ctensor) const
|
||||
{
|
||||
auto thr_tensor = make_tensor(std::forward<CTensor>(ctensor).data(), TiledMMA::thrfrg_C(ctensor.layout()));
|
||||
auto thr_tensor = make_tensor(std::forward<CTensor>(ctensor).data(), this->thrfrg_C(ctensor.layout()));
|
||||
|
||||
auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_)));
|
||||
return thr_tensor(thr_vmn, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
|
||||
@ -532,7 +526,7 @@ struct ThrMMA : TiledMMA
|
||||
auto
|
||||
partition_A(ATensor&& atensor) const
|
||||
{
|
||||
auto thr_tensor = make_tensor(std::forward<ATensor>(atensor).data(), TiledMMA::thrfrg_A(atensor.layout()));
|
||||
auto thr_tensor = make_tensor(std::forward<ATensor>(atensor).data(), this->thrfrg_A(atensor.layout()));
|
||||
|
||||
auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_)));
|
||||
return thr_tensor(thr_vmk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
|
||||
@ -543,7 +537,7 @@ struct ThrMMA : TiledMMA
|
||||
auto
|
||||
partition_B(BTensor&& btensor) const
|
||||
{
|
||||
auto thr_tensor = make_tensor(std::forward<BTensor>(btensor).data(), TiledMMA::thrfrg_B(btensor.layout()));
|
||||
auto thr_tensor = make_tensor(std::forward<BTensor>(btensor).data(), this->thrfrg_B(btensor.layout()));
|
||||
|
||||
auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_)));
|
||||
return thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
|
||||
@ -580,38 +574,32 @@ struct ThrMMA : TiledMMA
|
||||
|
||||
template <class MMA_Op,
|
||||
class MMAThrLayout = Layout<Shape<_1,_1,_1>>,
|
||||
class MMAValLayout = Layout<Shape<_1,_1,_1>>,
|
||||
class Permutations = Tile<Underscore,Underscore,Underscore>>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tiled_mma(MMA_Atom<MMA_Op> const&,
|
||||
make_tiled_mma(MMA_Atom<MMA_Op> const& mma_atom,
|
||||
MMAThrLayout const& thr_layout = {},
|
||||
MMAValLayout const& val_layout = {},
|
||||
Permutations const& permutations = {})
|
||||
{
|
||||
auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{});
|
||||
auto val_layout_mnk = append<3>(val_layout, Layout<_1,_0>{});
|
||||
auto permutation_mnk = append<3>(permutations, _);
|
||||
|
||||
return TiledMMA<MMA_Atom<MMA_Op>,
|
||||
decltype(thr_layout_mnk),
|
||||
decltype(val_layout_mnk),
|
||||
decltype(permutation_mnk)>{};
|
||||
decltype(permutation_mnk)>{mma_atom, thr_layout_mnk};
|
||||
}
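// Editor's aside: illustrative standalone sketch, not part of this header. With the ValLayoutMNK
// parameter removed, a TiledMMA is now built from just an atom, an AtomLayoutMNK, and an optional
// PermutationMNK; the SM80 atom below is one concrete choice.
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm80.hpp>
using namespace cute;

int main() {
  // 2x2x1 atoms of the 16x8x16 F16 MMA -> a 32x16x16 tile across 128 threads
  auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                  Layout<Shape<_2,_2,_1>>{});
  print(tiled_mma);                            // ThrLayoutVMNK, PermutationMNK, and the atom
  print(tile_shape(tiled_mma)); print("\n");   // (_32,_16,_16), via tile_size_mnk
  return 0;
}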
template <class MMA_Op,
|
||||
class MMAThrLayout = Layout<Shape<_1,_1,_1>>,
|
||||
class MMAValLayout = Layout<Shape<_1,_1,_1>>,
|
||||
class Permutations = Tile<Underscore,Underscore,Underscore>>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tiled_mma(MMA_Op const&,
|
||||
MMAThrLayout const& thr_layout = {},
|
||||
MMAValLayout const& val_layout = {},
|
||||
Permutations const& permutations = {})
|
||||
{
|
||||
// Attempt to wrap in an MMA_Atom<> and forward
|
||||
return make_tiled_mma(MMA_Atom<MMA_Op>{}, thr_layout, val_layout, permutations);
|
||||
return make_tiled_mma(MMA_Atom<MMA_Op>{}, thr_layout, permutations);
|
||||
}
|
||||
|
||||
//
|
||||
@ -680,28 +668,38 @@ partition_shape_B(TiledMMA<Args...> const& mma, Shape_NK const& shape_NK)
|
||||
// Size
|
||||
//
|
||||
|
||||
template <int... I, class... Args>
|
||||
template <int I, class... Args>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tile_size(TiledMMA<Args...> const& mma)
|
||||
{
|
||||
return size<I...>(typename TiledMMA<Args...>::TiledShape_MNK{});
|
||||
return mma.template tile_size_mnk<I>();
|
||||
}
|
||||
|
||||
template <int... I, class... Args>
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tile_shape(TiledMMA<Args...> const& mma)
|
||||
{
|
||||
return shape<I...>(typename TiledMMA<Args...>::TiledShape_MNK{});
|
||||
return make_shape(tile_size<0>(mma), tile_size<1>(mma), tile_size<2>(mma));
|
||||
}
|
||||
|
||||
// Deprecate?
|
||||
template <int... I, class... Args>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
size(TiledMMA<Args...> const& mma)
|
||||
{
|
||||
return size<I...>(typename TiledMMA<Args...>::ThrLayoutVMNK{});
|
||||
return size<I...>(mma.get_thr_layout_vmnk());
|
||||
}
|
||||
|
||||
// Alias
|
||||
template <int... I, class... Args>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
thr_size(TiledMMA<Args...> const& mma)
|
||||
{
|
||||
return size<I...>(mma.get_thr_layout_vmnk());
|
||||
}
|
||||
|
||||
//
|
||||
@ -715,33 +713,31 @@ print(MMA_Atom<MMA_Traits<Args...>> const&)
|
||||
{
|
||||
using Atom = MMA_Atom<MMA_Traits<Args...>>;
|
||||
print("MMA_Atom\n");
|
||||
print(" ThrID: "); print(typename Atom::ThrID{}); print("\n");
|
||||
print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n");
|
||||
print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n");
|
||||
print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n");
|
||||
print(" ThrID: "); print(typename Atom::ThrID{}); print("\n");
|
||||
print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n");
|
||||
print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n");
|
||||
print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n");
|
||||
}
|
||||
|
||||
template <class Atom, class TiledThr, class TiledVal, class TiledPerm>
|
||||
template <class Atom, class TiledThr, class TiledPerm>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(TiledMMA<Atom, TiledThr, TiledVal, TiledPerm> const& mma)
|
||||
print(TiledMMA<Atom, TiledThr, TiledPerm> const& mma)
|
||||
{
|
||||
using MMA = TiledMMA<Atom, TiledThr, TiledVal, TiledPerm>;
|
||||
print("TiledMMA\n");
|
||||
print(" TiledThr: "); print(TiledThr{}); print("\n");
|
||||
print(" TiledVal: "); print(TiledVal{}); print("\n");
|
||||
print(" TiledPerm: "); print(TiledPerm{}); print("\n");
|
||||
print(" TiledShape_MNK: "); print(typename MMA::TiledShape_MNK{}); print("\n");
|
||||
print(" ThrLayoutVMNK: "); print(typename MMA::ThrLayoutVMNK{}); print("\n");
|
||||
print(" ThrLayoutVMNK: "); print(mma.get_thr_layout_vmnk()); print("\n");
|
||||
print(" PermutationMNK: "); print(TiledPerm{}); print("\n");
|
||||
print(static_cast<Atom const&>(mma));
|
||||
}
|
||||
|
||||
template <class TiledMMA, class ThrVMNK>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(ThrMMA<TiledMMA, ThrVMNK> const&)
|
||||
print(ThrMMA<TiledMMA, ThrVMNK> const& thr_mma)
|
||||
{
|
||||
print(TiledMMA{});
|
||||
print("ThrMMA\n");
|
||||
print(" Thr VMNK: "); print(thr_mma.thr_vmnk_); print("\n");
|
||||
print(static_cast<TiledMMA>(thr_mma));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
@ -766,18 +762,6 @@ print_latex(TiledMMA<Args...> const& mma)
|
||||
layoutB_NK, thrID_B);
|
||||
}
|
||||
|
||||
// EXPERIMENTAL -- Doesn't work with Swizzled Thr TiledMMAs...
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
print_latex_2(TiledMMA<Args...> const& mma)
|
||||
{
|
||||
print_latex_mma(typename TiledMMA<Args...>::TiledShape_MNK{},
|
||||
mma.get_layoutC_TV(),
|
||||
mma.get_layoutA_TV(),
|
||||
mma.get_layoutB_TV());
|
||||
}
|
||||
|
||||
// MNK MMA Layout to console printer -- 8-value color coded by thread
|
||||
template <class LayoutC, class ThrIDC,
|
||||
class LayoutA, class ThrIDA,
|
||||
@ -943,122 +927,6 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
|
||||
printf(latex_footer);
|
||||
}
|
||||
|
||||
// ThrVal MMA Layout to Latex TIKZ -- 8-value color coded by thread
|
||||
template <class Shape_MNK,
|
||||
class LayoutC, class LayoutA, class LayoutB>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex_mma(Shape_MNK const& shape_mnk,
|
||||
LayoutC const& C, // (thr_idx,vid) -> (m,n)
|
||||
LayoutA const& A, // (thr_idx,vid) -> (m,k)
|
||||
LayoutB const& B) // (thr_idx,vid) -> (n,k)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
|
||||
|
||||
char const* latex_header =
|
||||
"\\documentclass{standalone}\n"
|
||||
"\\usepackage{tikz}\n"
|
||||
"\\usetikzlibrary{external}\n"
|
||||
"\\tikzexternalize\n"
|
||||
"\\begin{document}\n"
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n";
|
||||
char const* latex_footer =
|
||||
"\\end{tikzpicture}\n"
|
||||
"\\end{document}\n";
|
||||
|
||||
char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
|
||||
"{rgb,255:red,175;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,175;blue,175}",
|
||||
"{rgb,255:red,210;green,210;blue,255}",
|
||||
"{rgb,255:red,210;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,210;blue,210}"};
|
||||
|
||||
// Header
|
||||
printf("%% Shape_MNK: "); print(shape_mnk); printf("\n");
|
||||
printf("%% LayoutC : "); print(C); printf("\n");
|
||||
printf("%% LayoutA : "); print(A); printf("\n");
|
||||
printf("%% LayoutB : "); print(B); printf("\n\n");
|
||||
|
||||
printf(latex_header);
|
||||
|
||||
auto M = size<0>(shape_mnk);
|
||||
auto N = size<1>(shape_mnk);
|
||||
auto K = size<2>(shape_mnk);
|
||||
|
||||
// C starting at 0,0
|
||||
bool c_filled[M][N] = {};
|
||||
for (int t = 0; t < size<0>(C); ++t) {
|
||||
for (int v = 0; v < size<1>(C); ++v) {
|
||||
int m = C(t,v) % M;
|
||||
int n = C(t,v) / M;
|
||||
|
||||
if (not c_filled[m][n]) {
|
||||
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color_map[t % 8],
|
||||
m, n,
|
||||
t, v);
|
||||
c_filled[m][n] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A starting at 0,-size<1>(A)-1
|
||||
bool a_filled[M][K] = {};
|
||||
for (int t = 0; t < size<0>(A); ++t) {
|
||||
for (int v = 0; v < size<1>(A); ++v) {
|
||||
int m = A(t,v) % M;
|
||||
int k = A(t,v) / M;
|
||||
|
||||
if (not a_filled[m][k]) {
|
||||
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color_map[t % 8],
|
||||
m, k - 1 - K,
|
||||
t, v);
|
||||
a_filled[m][k] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// B starting at -size<1>(B)-1,0
|
||||
bool b_filled[N][K] = {};
|
||||
for (int t = 0; t < size<0>(B); ++t) {
|
||||
for (int v = 0; v < size<1>(B); ++v) {
|
||||
int n = B(t,v) % N;
|
||||
int k = B(t,v) / N;
|
||||
|
||||
if (not b_filled[n][k]) {
|
||||
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color_map[t % 8],
|
||||
k - 1 - K, n,
|
||||
t, v);
|
||||
b_filled[n][k] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A labels
|
||||
for (int m = 0, k = -1; m < M; ++m) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, m);
|
||||
}
|
||||
for (int k = 0, m = -1; k < K; ++k) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, k);
|
||||
}
|
||||
// B labels
|
||||
for (int n = 0, k = -1; n < N; ++n) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, n);
|
||||
}
|
||||
for (int k = 0, n = -1; k < K; ++k) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, k);
|
||||
}
|
||||
|
||||
// Footer
|
||||
printf(latex_footer);
|
||||
}
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -47,7 +47,7 @@
|
||||
#endif
|
||||
|
||||
#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \
|
||||
(defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
|
||||
(defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
|
||||
# define CUTE_UNROLL #pragma unroll
|
||||
# define CUTE_NO_UNROLL #pragma unroll 1
|
||||
#elif defined(__CUDACC_RTC__) || defined(__clang__)
|
||||
@ -120,6 +120,8 @@
|
||||
#include <cassert>
|
||||
#endif
|
||||
|
||||
#define CUTE_STATIC_V(x) decltype(x)::value
|
||||
|
||||
#define CUTE_STATIC_ASSERT static_assert
|
||||
#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__)
|
||||
|
||||
|
||||
@ -91,7 +91,7 @@ private:
|
||||
// Flag for fast branching on straddled elements
|
||||
static constexpr bool is_storage_unaligned = ((sizeof_bits_v<storage_type> % sizeof_bits_v<element_type>) != 0);
|
||||
|
||||
friend class subbyte_iterator<T>;
|
||||
friend struct subbyte_iterator<T>;
|
||||
|
||||
// Pointer to storage element
|
||||
storage_type* ptr_ = nullptr;
|
||||
@ -208,7 +208,7 @@ struct subbyte_iterator
|
||||
|
||||
private:
|
||||
|
||||
template <class, class> friend class swizzle_ptr;
|
||||
template <class, class> friend struct swizzle_ptr;
|
||||
|
||||
// Pointer to storage element
|
||||
storage_type* ptr_ = nullptr;
|
||||
|
||||
@ -327,7 +327,7 @@ ceil_div(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
|
||||
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implictly 1
|
||||
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1
|
||||
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); });
|
||||
} else {
|
||||
return (a + b - Int<1>{}) / b;
|
||||
@ -336,6 +336,28 @@ ceil_div(IntTupleA const& a, IntTupleB const& b)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// round_up
|
||||
// Round @a a up to the nearest multiple of @a b.
|
||||
// For negative numbers, rounds away from zero.
|
||||
//
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
round_up(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
|
||||
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1
|
||||
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); });
|
||||
} else {
|
||||
return ((a + b - Int<1>{}) / b) * b;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
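A minimal usage sketch (not part of this commit): ceil_div rounds the quotient up, while the new round_up rounds the value itself up to a multiple of b, elementwise over IntTuples with missing trailing modes treated as 1.

#include <cute/int_tuple.hpp>

void round_up_example() {
  using namespace cute;
  static_assert(ceil_div(10, 4) == 3);
  static_assert(round_up(10, 4) == 12);
  auto t = round_up(make_tuple(10, 7), make_tuple(4, 8));  // elementwise -> (12, 8)
  CUTE_STATIC_ASSERT_V(rank(t) == Int<2>{});
}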
|
||||
|
||||
/** Division for Shapes
|
||||
* Case Tuple Tuple:
|
||||
* Perform shape_div element-wise
|
||||
@ -429,6 +451,7 @@ template <class A, class B>
|
||||
using is_congruent = decltype(congruent(declval<A>(), declval<B>()));
|
||||
|
||||
/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division)
|
||||
* weakly_congruent is a partial order on A and B: A <= B
|
||||
*/
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -458,7 +481,7 @@ using is_weakly_congruent = decltype(weakly_congruent(declval<A>(), declval<B>()
|
||||
|
||||
/** Test if Shape B is compatible with Shape A:
|
||||
* Any coordinate into A can also be used as a coordinate into B
|
||||
* A <= B is a partially ordered set of factored shapes
|
||||
* compatible is a partial order on A and B: A <= B
|
||||
*/
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -487,7 +510,8 @@ template <class A, class B>
|
||||
using is_compatible = decltype(compatible(declval<A>(), declval<B>()));
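A short sketch (not part of this commit) of the compatible relation described above: compatible(a,b) is true when every coordinate of shape a is also a valid coordinate of shape b.

#include <cute/layout.hpp>

void compatible_example() {
  using namespace cute;
  auto a = make_shape(Int<8>{}, Int<4>{});
  auto b = make_shape(make_shape(Int<2>{}, Int<4>{}), Int<4>{});  // refines mode 0 of a
  CUTE_STATIC_ASSERT_V(compatible(a, b));              // 8 == 2*4, so coordinates of a map into b
  static_assert(!decltype(compatible(b, a))::value);   // partial order: not symmetric
}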
|
||||
|
||||
/** Test if Shape B is weakly compatible with Shape A:
|
||||
* Shape B divides Shape A at some level of refinement
|
||||
* Shape B is a multiple of a shape that is compatible with Shape A
|
||||
* weakly_compatible is a partial order on A and B: A <= B
|
||||
*/
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -502,7 +526,7 @@ weakly_compatible(IntTupleA const& a, IntTupleB const& b)
|
||||
[](auto const&... z) { return (true_type{} && ... && z); });
|
||||
}
|
||||
} else if constexpr (is_integral<IntTupleA>::value) {
|
||||
return a % size(b) == Int<0>{};
|
||||
return size(b) % a == Int<0>{};
|
||||
} else if constexpr (is_integral<IntTupleB>::value) {
|
||||
return false_type{};
|
||||
} else {
|
||||
|
||||
@ -981,7 +981,6 @@ auto
|
||||
composition(Layout<LShape,LStride> const& lhs,
|
||||
Layout<RShape,RStride> const& rhs)
|
||||
{
|
||||
//return detail::composition_impl(flatten(lhs), rhs.shape(), rhs.stride());
|
||||
return detail::composition_impl(lhs, rhs.shape(), rhs.stride());
|
||||
}
|
||||
|
||||
@ -997,8 +996,8 @@ composition(Layout<LShape,LStride> const& lhs,
|
||||
return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq<tuple_size<IntTuple>::value>{}, seq<>{}, seq<>{});
|
||||
} else if constexpr (is_underscore<IntTuple>::value) {
|
||||
return lhs;
|
||||
} else {
|
||||
return composition(lhs, make_layout(rhs));
|
||||
} else if constexpr (is_integral<IntTuple>::value) {
|
||||
return detail::composition_impl(lhs, rhs, Int<1>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
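A sketch (not part of this commit) of the dispatch change above: an integral second argument now goes through composition_impl with stride 1, i.e. it behaves like composing with the layout N:1.

#include <cute/layout.hpp>

void composition_example() {
  using namespace cute;
  auto layout = make_layout(make_shape(Int<8>{}, Int<2>{}));  // (8,2):(1,8)
  auto sub    = composition(layout, Int<4>{});                // same as composition(layout, 4:1)
  CUTE_STATIC_ASSERT_V(size(sub) == Int<4>{});
  CUTE_STATIC_ASSERT_V(stride<0>(sub) == Int<1>{});
}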
|
||||
@ -1097,15 +1096,18 @@ inverse_seq(Shape const& shape, Stride const& stride, seq<Is...>)
|
||||
auto next_I = cute::find_if(stride, [](auto a) { return is_constant<NextStride, decltype(a)>{}; });
|
||||
|
||||
if constexpr (next_I == decltype(rank(stride))::value) {
|
||||
// If not found, return current seq
|
||||
return seq<Is...>{};
|
||||
} else {
|
||||
// auto next_stride = get<next_I>(shape) * get<next_I>(stride);
|
||||
// NOTE: Needed for g++-7
|
||||
using next_stride = decltype(get<next_I>(shape) * get<next_I>(stride));
|
||||
|
||||
if constexpr (is_static<next_stride>::value) {
|
||||
if constexpr (is_static<next_stride>::value && !is_constant<NextStride, next_stride>::value) {
|
||||
// If next_stride is static and unique, then continue
|
||||
return inverse_seq<next_stride::value>(shape, stride, seq<Is..., next_I>{});
|
||||
} else {
|
||||
// Else return current seq + next_I
|
||||
return seq<Is..., next_I>{};
|
||||
}
|
||||
}
|
||||
@ -1340,28 +1342,24 @@ template <class LShape, class LStride,
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
logical_divide(Layout<LShape,LStride> const& layout,
|
||||
Layout<TShape,TStride> const& tile)
|
||||
Layout<TShape,TStride> const& tiler)
|
||||
{
|
||||
//CUTE_STATIC_ASSERT_V(size(layout) % size(tile) == Int<0>{},
|
||||
// "Tiling does not evenly divide the block");
|
||||
// NOTE: With tiles that have stride-0, this doesn't have to be true
|
||||
|
||||
return composition(layout, make_layout(tile, complement(tile, size(layout))));
|
||||
return composition(layout, make_layout(tiler, complement(tiler, size(layout))));
|
||||
}
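A minimal sketch (not part of this commit) of what the composition-with-complement formula above produces: mode 0 of the result indexes within a tile and mode 1 indexes across tiles.

#include <cute/layout.hpp>

void logical_divide_example() {
  using namespace cute;
  auto layout = make_layout(Int<16>{});           // 16:1
  auto tiler  = make_layout(Int<4>{});            // tile of 4 consecutive elements
  auto result = logical_divide(layout, tiler);    // ((4),(4)):((1),(4))
  CUTE_STATIC_ASSERT_V(size<0>(result) == Int<4>{});  // elements within one tile
  CUTE_STATIC_ASSERT_V(size<1>(result) == Int<4>{});  // number of tiles
}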
|
||||
|
||||
template <class LShape, class LStride, class IntTuple>
|
||||
template <class LShape, class LStride, class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
logical_divide(Layout<LShape,LStride> const& layout,
|
||||
IntTuple const& tile)
|
||||
Tiler const& tiler)
|
||||
{
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
static_assert(tuple_size<IntTuple>::value <= Layout<LShape,LStride>::rank, "logical_divide: Too many modes in tile.");
|
||||
return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_divide(l,t); });
|
||||
} else if constexpr (is_underscore<IntTuple>::value) {
|
||||
if constexpr (is_tuple<Tiler>::value) {
|
||||
static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_divide: Too many modes in tiler.");
|
||||
return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_divide(l,t); });
|
||||
} else if constexpr (is_underscore<Tiler>::value) {
|
||||
return layout;
|
||||
} else if constexpr (is_integral<IntTuple>::value) {
|
||||
return logical_divide(layout, make_layout(tile));
|
||||
} else if constexpr (is_integral<Tiler>::value) {
|
||||
return logical_divide(layout, make_layout(tiler));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@ -1374,24 +1372,24 @@ logical_divide(Layout<LShape,LStride> const& layout,
|
||||
//
|
||||
|
||||
template <class LShape, class LStride,
|
||||
class Tile>
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zipped_divide(Layout<LShape,LStride> const& layout,
|
||||
Tile const& tile)
|
||||
Tiler const& tiler)
|
||||
{
|
||||
return tile_unzip(logical_divide(layout, tile), tile);
|
||||
return tile_unzip(logical_divide(layout, tiler), tiler);
|
||||
}
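A sketch (not part of this commit) contrasting zipped_divide with tiled_divide, whose unpacking behavior is described in the comment below: zipped_divide groups all tile modes into mode 0 and all rest modes into mode 1.

#include <cute/layout.hpp>

void zipped_divide_example() {
  using namespace cute;
  auto layout = make_layout(make_shape(Int<8>{}, Int<6>{}));              // (8,6) column-major
  auto tiler  = make_tile(make_layout(Int<4>{}), make_layout(Int<2>{}));  // 4x2 tile
  auto z = zipped_divide(layout, tiler);   // ((4,2),(2,3))
  auto t = tiled_divide(layout, tiler);    // ((4,2),2,3)
  CUTE_STATIC_ASSERT_V(size<0>(z) == Int<8>{});  // one tile holds 4*2 elements
  CUTE_STATIC_ASSERT_V(rank(t)   == Int<3>{});   // tile mode plus two rest modes
}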
|
||||
|
||||
// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y)
|
||||
template <class LShape, class LStride,
|
||||
class Tile>
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tiled_divide(Layout<LShape,LStride> const& layout,
|
||||
Tile const& tile)
|
||||
Tiler const& tiler)
|
||||
{
|
||||
auto div = zipped_divide(layout, tile);
|
||||
auto div = zipped_divide(layout, tiler);
|
||||
|
||||
auto R = rank<1>(div);
|
||||
return div(_, repeat<R>(_));
|
||||
@ -1399,13 +1397,13 @@ tiled_divide(Layout<LShape,LStride> const& layout,
|
||||
|
||||
// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y)
|
||||
template <class LShape, class LStride,
|
||||
class Tile>
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flat_divide(Layout<LShape,LStride> const& layout,
|
||||
Tile const& tile)
|
||||
Tiler const& tiler)
|
||||
{
|
||||
auto div = zipped_divide(layout, tile);
|
||||
auto div = zipped_divide(layout, tiler);
|
||||
|
||||
auto R0 = rank<0>(div);
|
||||
auto R1 = rank<1>(div);
|
||||
@ -1421,24 +1419,24 @@ template <class LShape, class LStride,
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
logical_product(Layout<LShape,LStride> const& layout,
|
||||
Layout<TShape,TStride> const& tile)
|
||||
Layout<TShape,TStride> const& tiler)
|
||||
{
|
||||
return make_layout(layout, composition(complement(layout, size(layout)*cosize(tile)), tile));
|
||||
return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler));
|
||||
}
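A minimal sketch (not part of this commit) of the product formula above: logical_product keeps the base layout as mode 0 and appends a mode that repeats it cosize(tiler) times.

#include <cute/layout.hpp>

void logical_product_example() {
  using namespace cute;
  auto layout = make_layout(Int<4>{});            // 4:1
  auto tiler  = make_layout(Int<3>{});            // repeat 3 times
  auto prod   = logical_product(layout, tiler);   // ((4),(3)):((1),(4))
  CUTE_STATIC_ASSERT_V(size(prod)    == Int<12>{});
  CUTE_STATIC_ASSERT_V(size<1>(prod) == Int<3>{});
}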
|
||||
|
||||
template <class LShape, class LStride, class IntTuple>
|
||||
template <class LShape, class LStride, class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
logical_product(Layout<LShape,LStride> const& layout,
|
||||
IntTuple const& tile)
|
||||
Tiler const& tiler)
|
||||
{
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
static_assert(tuple_size<IntTuple>::value <= Layout<LShape,LStride>::rank);
|
||||
return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_product(l,t); });
|
||||
} else if constexpr (is_underscore<IntTuple>::value) {
|
||||
if constexpr (is_tuple<Tiler>::value) {
|
||||
static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_product: Too many modes in tiler.");
|
||||
return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); });
|
||||
} else if constexpr (is_underscore<Tiler>::value) {
|
||||
return layout;
|
||||
} else if constexpr (is_integral<IntTuple>::value) {
|
||||
return logical_product(layout, make_layout(tile));
|
||||
} else if constexpr (is_integral<Tiler>::value) {
|
||||
return logical_product(layout, make_layout(tiler));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@ -1451,45 +1449,43 @@ logical_product(Layout<LShape,LStride> const& layout,
|
||||
//
|
||||
|
||||
template <class LShape, class LStride,
|
||||
class Tile>
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zipped_product(Layout<LShape,LStride> const& layout,
|
||||
Tile const& tile)
|
||||
Tiler const& tiler)
|
||||
{
|
||||
return tile_unzip(logical_product(layout, tile), tile);
|
||||
return tile_unzip(logical_product(layout, tiler), tiler);
|
||||
}
|
||||
|
||||
// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y)
|
||||
template <class LShape, class LStride,
|
||||
class Tile>
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tiled_product(Layout<LShape,LStride> const& layout,
|
||||
Tile const& tile)
|
||||
Tiler const& tiler)
|
||||
{
|
||||
auto div = zipped_product(layout, tile);
|
||||
auto div = zipped_product(layout, tiler);
|
||||
|
||||
auto R = rank(tile);
|
||||
auto R = rank<1>(div);
|
||||
return div(_, repeat<R>(_));
|
||||
}
|
||||
|
||||
// Attempts to reproduce layout "block" over layout "layout"
|
||||
// That is, think of every element of "layout" as a "block"
|
||||
// Attempts to reproduce a layout over a tiler
|
||||
// That is, think of every element of "tiler" as a "layout"
|
||||
// and return the layout of the resulting structure
|
||||
template <class TShape, class TStride,
|
||||
class UShape, class UStride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
blocked_product(Layout<TShape,TStride> const& block,
|
||||
Layout<UShape,UStride> const& layout)
|
||||
blocked_product(Layout<TShape,TStride> const& layout,
|
||||
Layout<UShape,UStride> const& tiler)
|
||||
{
|
||||
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
|
||||
auto padded_block = append<R>(block);
|
||||
auto padded_layout = append<R>(layout);
|
||||
|
||||
auto result = logical_product(padded_block, padded_layout);
|
||||
|
||||
auto result = logical_product(append<R>(layout), append<R>(tiler));
|
||||
|
||||
return coalesce(zip(get<0>(result), get<1>(result)), repeat<R>(Int<1>{}));
|
||||
}
|
||||
|
||||
@ -1497,14 +1493,12 @@ template <class TShape, class TStride,
|
||||
class UShape, class UStride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
raked_product(Layout<TShape,TStride> const& block,
|
||||
Layout<UShape,UStride> const& layout)
|
||||
raked_product(Layout<TShape,TStride> const& layout,
|
||||
Layout<UShape,UStride> const& tiler)
|
||||
{
|
||||
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
|
||||
auto padded_block = append<R>(block);
|
||||
auto padded_layout = append<R>(layout);
|
||||
|
||||
auto result = logical_product(padded_block, padded_layout);
|
||||
auto result = logical_product(append<R>(layout), append<R>(tiler));
|
||||
|
||||
return coalesce(zip(get<1>(result), get<0>(result)), repeat<R>(Int<1>{}));
|
||||
}
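A sketch (not part of this commit) of the block-over-tiler reproduction described in the comments above: blocked_product keeps each block contiguous in the result, while raked_product interleaves block and tiler modes.

#include <cute/layout.hpp>

void blocked_product_example() {
  using namespace cute;
  auto block = make_layout(make_shape(Int<2>{}, Int<2>{}));   // 2x2 block
  auto grid  = make_layout(make_shape(Int<3>{}, Int<4>{}));   // 3x4 arrangement of blocks
  auto prod  = blocked_product(block, grid);                  // 6x8 overall layout
  CUTE_STATIC_ASSERT_V(size<0>(prod) == Int<6>{});
  CUTE_STATIC_ASSERT_V(size<1>(prod) == Int<8>{});
}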
|
||||
|
||||
@ -473,6 +473,16 @@ zipped_divide(ComposedLayout<A,O,B> const& a,
|
||||
return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b));
|
||||
}
|
||||
|
||||
template <class A, class O, class B,
|
||||
class Tile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flat_divide(ComposedLayout<A,O,B> const& a,
|
||||
Tile const& b)
|
||||
{
|
||||
return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b));
|
||||
}
|
||||
|
||||
template <class A, class O, class B,
|
||||
class Tile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
|
||||
@ -181,7 +181,7 @@ template <class T0, class T1, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) {
|
||||
return make_tuple_iter(cute::make_tuple(t0, t1, ts...));
|
||||
return make_inttuple_iter(cute::make_tuple(t0, t1, ts...));
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@ -148,7 +148,9 @@ using _96 = Int<96>;
|
||||
using _128 = Int<128>;
|
||||
using _192 = Int<192>;
|
||||
using _256 = Int<256>;
|
||||
using _384 = Int<384>;
|
||||
using _512 = Int<512>;
|
||||
using _768 = Int<768>;
|
||||
using _1024 = Int<1024>;
|
||||
using _2048 = Int<2048>;
|
||||
using _4096 = Int<4096>;
|
||||
|
||||
@ -97,7 +97,7 @@ template <class T, class U,
|
||||
__CUTE_REQUIRES(is_std_integral<T>::value &&
|
||||
is_std_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
cute::common_type_t<T, U>
|
||||
gcd(T t, U u) {
|
||||
while (true) {
|
||||
if (t == 0) { return u; }
|
||||
@ -112,7 +112,7 @@ template <class T, class U,
|
||||
__CUTE_REQUIRES(is_std_integral<T>::value &&
|
||||
is_std_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
cute::common_type_t<T, U>
|
||||
lcm(T const& t, U const& u) {
|
||||
return (t / gcd(t,u)) * u;
|
||||
}
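A brief sketch (not part of this commit): with the return type spelled as cute::common_type_t<T, U> instead of auto, mixed-width calls are well formed and both results land in the common (wider) type.

#include <cute/numeric/math.hpp>
#include <cstdint>

void gcd_lcm_example() {
  auto g = cute::gcd(uint64_t{24}, uint32_t{36});   // 12, deduced as uint64_t
  auto l = cute::lcm(uint64_t{4},  uint32_t{6});    // 12, deduced as uint64_t
  static_assert(sizeof(g) == sizeof(uint64_t) && sizeof(l) == sizeof(uint64_t));
}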
|
||||
|
||||
@ -233,14 +233,14 @@ CUTE_HOST_DEVICE void print(T const* const ptr)
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void print(counting_iterator<T> ptr)
|
||||
{
|
||||
printf("counting_iter_"); print(ptr.n_);
|
||||
printf("counting_iter("); print(ptr.n_); printf(")");
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr)
|
||||
{
|
||||
return os << "counting_iter_" << ptr.n_;
|
||||
return os << "counting_iter(" << ptr.n_ << ")";
|
||||
}
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
|
||||
@ -990,13 +990,10 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
print(tensor); print(":\n");
|
||||
|
||||
auto format = get_format(tensor(0));
|
||||
using type = typename decltype(format)::type;
|
||||
|
||||
if constexpr (Layout::rank == 1)
|
||||
{
|
||||
for (int m = 0; m < size(tensor); ++m) {
|
||||
printf(format.format, format.digits, type(tensor(m)));
|
||||
pretty_print(tensor(m));
|
||||
printf("\n");
|
||||
}
|
||||
} else
|
||||
@ -1004,7 +1001,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
for (int m = 0; m < size<0>(tensor); ++m) {
|
||||
for (int n = 0; n < size<1>(tensor); ++n) {
|
||||
printf(format.format, format.digits, type(tensor(m,n)));
|
||||
pretty_print(tensor(m,n));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@ -1013,7 +1010,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
print_tensor(tensor(_,_,0));
|
||||
for (int k = 1; k < size<2>(tensor); ++k) {
|
||||
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n");
|
||||
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n");
|
||||
print_tensor(tensor(_,_,k));
|
||||
}
|
||||
} else
|
||||
@ -1021,7 +1018,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
print_tensor(tensor(_,_,_,0));
|
||||
for (int p = 1; p < size<3>(tensor); ++p) {
|
||||
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n");
|
||||
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n");
|
||||
print_tensor(tensor(_,_,_,p));
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,61 +57,6 @@ num_digits(int x)
|
||||
10)))))))));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct format_and_size {
|
||||
using type = T;
|
||||
char const* format;
|
||||
int digits;
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<int>
|
||||
get_format(bool) {
|
||||
return {"%*d", 3};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<int32_t>
|
||||
get_format(int32_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<uint32_t>
|
||||
get_format(uint32_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<int64_t>
|
||||
get_format(int64_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<uint64_t>
|
||||
get_format(uint64_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<float>
|
||||
get_format(half_t) {
|
||||
return {"%*.2f", 8};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<float>
|
||||
get_format(float) {
|
||||
return {"%*.2e", 10};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<double>
|
||||
get_format(double) {
|
||||
return {"%*.3e", 11};
|
||||
}
|
||||
|
||||
//
|
||||
// print dispatcher
|
||||
//
|
||||
@ -195,4 +140,54 @@ print(char const* format) {
|
||||
printf("%s", format);
|
||||
}
|
||||
|
||||
//
|
||||
// pretty printing
|
||||
//
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(T const& v) {
|
||||
printf(" "); print(v);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(bool const& v) {
|
||||
printf("%*d", 3, int(v));
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(int32_t const& v) {
|
||||
printf("%*d", 5, v);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(uint32_t const& v) {
|
||||
printf("%*d", 5, v);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(int64_t const& v) {
|
||||
printf("%*lld", 5, static_cast<long long>(v));
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(uint64_t const& v) {
|
||||
printf("%*llu", 5, static_cast<unsigned long long>(v));
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(half_t const& v) {
|
||||
printf("%*.2f", 8, float(v));
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(float const& v) {
|
||||
printf("%*.2e", 10, v);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(double const& v) {
|
||||
printf("%*.3e", 11, v);
|
||||
}
|
||||
|
||||
} // end namespace cute
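A small sketch (not part of this commit) of the effect of this change: print_tensor now routes each element through the pretty_print overloads above instead of per-type format descriptors.

#include <cute/tensor.hpp>

void print_tensor_example() {
  using namespace cute;
  float data[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
  auto t = make_tensor(&data[0], make_shape(Int<2>{}, Int<3>{}));  // 2x3, column-major
  print_tensor(t);  // each element is printed via pretty_print(float), i.e. "%*.2e" with width 10
}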
|
||||
|
||||
@ -122,6 +122,9 @@ using CUTE_STL_NAMESPACE::is_empty_v;
|
||||
|
||||
using CUTE_STL_NAMESPACE::invoke_result_t;
|
||||
|
||||
using CUTE_STL_NAMESPACE::common_type;
|
||||
using CUTE_STL_NAMESPACE::common_type_t;
|
||||
|
||||
// <utility>
|
||||
using CUTE_STL_NAMESPACE::declval;
|
||||
|
||||
|
||||
@ -47,6 +47,18 @@ namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Enumerates the reserved named barriers to avoid potential conflicts
|
||||
// This enum class specifies the NamedBarriers reserved by CUTLASS.
|
||||
enum class ReservedNamedBarriers {
|
||||
EpilogueBarrier = 0,
|
||||
TransposeBarrier = 1,
|
||||
TransformBarrier = 2,
|
||||
StreamkBarrier0 = 3,
|
||||
StreamkBarrier1 = 4
|
||||
, FirstUserBarrier = StreamkBarrier1 + 1
|
||||
};
|
||||
|
||||
|
||||
class NamedBarrier {
|
||||
|
||||
// Data Members:
|
||||
@ -60,9 +72,19 @@ class NamedBarrier {
|
||||
|
||||
public:
|
||||
|
||||
// Constructor for CUTLASS developers:
|
||||
// effective barrier ID starts from 0
|
||||
CUTLASS_DEVICE
|
||||
NamedBarrier(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers)
|
||||
: num_threads_(num_threads), id_(static_cast<uint32_t>(reserved_named_barriers)) {}
|
||||
|
||||
// Constructor for CUTLASS users:
|
||||
// effective barrier ID starts from ReservedNamedBarrierCount
|
||||
CUTLASS_DEVICE
|
||||
NamedBarrier(uint32_t num_threads, uint32_t id = 0)
|
||||
: num_threads_(num_threads), id_(id) {}
|
||||
: num_threads_(num_threads), id_(id + ReservedNamedBarrierCount) {
|
||||
CUTLASS_ASSERT(id + ReservedNamedBarrierCount <= HardwareMaxNumNamedBarriers && "Effective barrier_id should not exceed 16.");
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void arrive_and_wait() const {
|
||||
@ -80,8 +102,52 @@ class NamedBarrier {
|
||||
}
|
||||
|
||||
// Static variants
|
||||
|
||||
// Calling interface for CUTLASS users:
|
||||
// effective barrier ID starts from ReservedNamedBarrierCount
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) {
|
||||
arrive_and_wait_internal(num_threads, barrier_id + ReservedNamedBarrierCount);
|
||||
}
|
||||
|
||||
// Calling interface for CUTLASS developers:
|
||||
// effective barrier ID starts from 0
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_and_wait(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) {
|
||||
arrive_and_wait_internal(num_threads, static_cast<int>(reserved_named_barriers));
|
||||
}
|
||||
|
||||
// Calling interface for CUTLASS users:
|
||||
// effective barrier ID starts from ReservedNamedBarrierCount
|
||||
CUTLASS_DEVICE
|
||||
static void arrive(uint32_t num_threads, uint32_t barrier_id) {
|
||||
arrive_internal(num_threads, barrier_id + ReservedNamedBarrierCount);
|
||||
}
|
||||
|
||||
// Calling interface for CUTLASS developers:
|
||||
// effective barrier ID starts from 0
|
||||
CUTLASS_DEVICE
|
||||
static void arrive(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) {
|
||||
arrive_internal(num_threads, static_cast<int>(reserved_named_barriers));
|
||||
}
|
||||
|
||||
// Calling interface for CUTLASS users:
|
||||
// effective barrier ID starts from ReservedNamedBarrierCount
|
||||
CUTLASS_DEVICE
|
||||
static void sync(uint32_t num_threads, uint32_t barrier_id) {
|
||||
sync_internal(num_threads, barrier_id + ReservedNamedBarrierCount);
|
||||
}
|
||||
|
||||
// Calling interface for CUTLASS developers:
|
||||
// effective barrier ID starts from 0
|
||||
CUTLASS_DEVICE
|
||||
static void sync(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) {
|
||||
sync_internal(num_threads, static_cast<int>(reserved_named_barriers));
|
||||
}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
@ -90,7 +156,7 @@ class NamedBarrier {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void arrive(uint32_t num_threads, uint32_t barrier_id) {
|
||||
static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
@ -99,9 +165,16 @@ class NamedBarrier {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void sync(uint32_t num_threads, uint32_t barrier_id) {
|
||||
NamedBarrier::arrive_and_wait(num_threads, barrier_id);
|
||||
static void sync_internal(uint32_t num_threads, uint32_t barrier_id) {
|
||||
NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id);
|
||||
}
|
||||
|
||||
public:
|
||||
// Currently we reserve 8 NamedBarriers for CUTLASS' own use cases,
|
||||
// while leaving the remaining for general users.
|
||||
static const uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(ReservedNamedBarriers::FirstUserBarrier);
|
||||
static const uint32_t HardwareMaxNumNamedBarriers = 16;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
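A device-side sketch (not part of this commit) of the two calling interfaces introduced above: user-facing barrier IDs are offset by ReservedNamedBarrierCount, while CUTLASS-internal code names a ReservedNamedBarriers enumerator directly.

#include "cutlass/arch/barrier.h"

__device__ void named_barrier_example() {
  // User interface: logical ID 0 maps to hardware barrier ReservedNamedBarrierCount.
  cutlass::arch::NamedBarrier::sync(128, /*barrier_id=*/0);

  // Developer interface: reserved barriers keep their fixed hardware IDs.
  cutlass::arch::NamedBarrier::sync(128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}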
|
||||
|
||||
@ -80,7 +80,7 @@ public:
|
||||
static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits<T>::value;
|
||||
|
||||
/// Number of storage elements
|
||||
static size_t const kStorageElements = N / kElementsPerStoredItem;
|
||||
static size_t const kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem;
|
||||
|
||||
/// Number of logical elements
|
||||
static size_t const kElements = N;
|
||||
|
||||
@ -68,7 +68,7 @@ template <
|
||||
struct NamedBarrierSync {
|
||||
CUTLASS_DEVICE
|
||||
static void sync() {
|
||||
cutlass::arch::NamedBarrier::sync(ThreadCount, BarrierId);
|
||||
cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast<arch::ReservedNamedBarriers>(BarrierId));
|
||||
}
|
||||
};
|
||||
|
||||
@ -227,9 +227,9 @@ template <
|
||||
uint32_t MaxNumNamedBarriers = 16
|
||||
>
|
||||
struct NamedBarrierManager {
|
||||
static constexpr uint32_t HardwareMaxNumNamedBarriers = 16;
|
||||
static_assert(MaxNumNamedBarriers <= HardwareMaxNumNamedBarriers);
|
||||
static_assert(MaxNumNamedBarriers + Offset <= HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15");
|
||||
|
||||
static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers);
|
||||
static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15");
|
||||
|
||||
// Number of threads participating in the barrier
|
||||
static constexpr uint32_t ThreadCount = ThreadCount_;
|
||||
|
||||
@ -55,6 +55,7 @@
|
||||
#include <cstring>
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
@ -83,6 +84,28 @@ struct alignas(2) bfloat16_t {
|
||||
return h;
|
||||
}
|
||||
|
||||
private:
|
||||
struct from_32_bit_integer_t {};
|
||||
static constexpr from_32_bit_integer_t from_32_bit_integer{};
|
||||
|
||||
template<class T>
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit bfloat16_t(from_32_bit_integer_t, T x) {
|
||||
static_assert(cutlass::platform::is_integral<T>::value && sizeof(T) == 4, "Requires 32-bit integer");
|
||||
|
||||
float flt = static_cast<float>(x);
|
||||
uint32_t bits;
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
bits = reinterpret_cast<uint32_t &>(flt);
|
||||
#else
|
||||
std::memcpy(&bits, &flt, sizeof(bits));
|
||||
#endif
|
||||
|
||||
storage = uint16_t(bits >> 16);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Default constructor
|
||||
bfloat16_t() = default;
|
||||
|
||||
@ -129,18 +152,10 @@ struct alignas(2) bfloat16_t {
|
||||
|
||||
/// Integer conversion - round toward nearest
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit bfloat16_t(int x) {
|
||||
float flt = static_cast<float>(x);
|
||||
uint32_t bits;
|
||||
explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {}
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
bits = reinterpret_cast<uint32_t &>(flt);
|
||||
#else
|
||||
std::memcpy(&bits, &flt, sizeof(bits));
|
||||
#endif
|
||||
|
||||
storage = uint16_t(bits >> 16);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {}
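A short sketch (not part of this commit): both 32-bit integer constructors now delegate to the private from_32_bit_integer helper, so int and uint32_t inputs are converted through float with identical rounding.

#include "cutlass/bfloat16.h"

CUTLASS_HOST_DEVICE void bfloat16_ctor_example() {
  cutlass::bfloat16_t a(int(1000));         // existing int path
  cutlass::bfloat16_t b(uint32_t(1000u));   // new uint32_t path, same rounding behavior
  float fa = float(a);                      // ~1000 within bf16 precision
  (void)fa; (void)b;
}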
|
||||
|
||||
/// Converts to float
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -35,7 +35,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdio>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
@ -28,6 +28,16 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain
|
||||
existing integrations of CUTLASS require C++11 host compilers.
|
||||
|
||||
Until this requirement can be lifted, certain headers with this annotation are required
|
||||
to remain consistent with C++11 syntax.
|
||||
|
||||
C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
|
||||
147
include/cutlass/cuda_host_adapter.hpp
Normal file
@ -0,0 +1,147 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Interface between a CUTLASS device-wide operator and CUDA.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass/platform/platform.h"
|
||||
#if ! defined(__CUDACC_RTC__)
|
||||
#include <cstdio>
|
||||
#endif
|
||||
|
||||
#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))
|
||||
# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Macro-level guard for CUDA Host Adapter
|
||||
//
|
||||
#if !defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER)
|
||||
#define CUTLASS_ENABLE_CUDA_HOST_ADAPTER false
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and
|
||||
/// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API.
|
||||
struct CudaHostAdapter {
|
||||
|
||||
/// Limit the number of kernels
|
||||
static constexpr int32_t kMaximumKernelCount = 4;
|
||||
|
||||
/// Maximum cluster size
|
||||
static constexpr int MaxClusterSize = 32;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Handles
|
||||
void *kernel_handles[kMaximumKernelCount];
|
||||
int32_t kernel_count = 0;
|
||||
|
||||
CudaHostAdapter() = default;
|
||||
|
||||
/// Dtor
|
||||
virtual ~CudaHostAdapter() {}
|
||||
|
||||
/// Copy Ctor deleted
|
||||
CudaHostAdapter(const CudaHostAdapter&) = delete;
|
||||
|
||||
/// Copy Assignment deleted
|
||||
CudaHostAdapter& operator=(const CudaHostAdapter&) = delete;
|
||||
|
||||
/// Move ctor deleted
|
||||
CudaHostAdapter(CudaHostAdapter&&) = delete;
|
||||
|
||||
/// Move assignment deleted
|
||||
CudaHostAdapter& operator=(CudaHostAdapter&&) = delete;
|
||||
|
||||
|
||||
/// Ctor
|
||||
inline CudaHostAdapter(
|
||||
void **kernel_handles_,
|
||||
int32_t kernel_count_
|
||||
):
|
||||
kernel_count(kernel_count_)
|
||||
{
|
||||
CUTLASS_ASSERT(kernel_count >= 0);
|
||||
for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) {
|
||||
kernel_handles[i] = kernel_handles_[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Queries the occupancy of a kernel
|
||||
virtual Status query_occupancy(
|
||||
int32_t *device_sms,
|
||||
int32_t *sm_occupancy,
|
||||
int32_t kernel_index,
|
||||
int32_t thread_count,
|
||||
int32_t smem_size) = 0;
|
||||
|
||||
/// Launches a kernel without using Threadblock Clusters.
|
||||
virtual Status launch(
|
||||
dim3 const grid_dims,
|
||||
dim3 const block_dims,
|
||||
size_t const smem_size,
|
||||
cudaStream_t cuda_stream,
|
||||
void** kernel_params,
|
||||
int32_t kernel_index) = 0;
|
||||
|
||||
/// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters.
|
||||
virtual Status launch(
|
||||
dim3 const grid_dims,
|
||||
dim3 const cluster_dims,
|
||||
dim3 const block_dims,
|
||||
size_t const smem_size,
|
||||
cudaStream_t cuda_stream,
|
||||
void** kernel_params,
|
||||
int32_t kernel_index) = 0;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
64
include/cutlass/detail/collective.hpp
Normal file
@ -0,0 +1,64 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/container/tuple.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <size_t I, class Tuple>
|
||||
struct deduce_mixed_width_dtype {
|
||||
static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represent Operand, Scale, and Bias, respectively.");
|
||||
|
||||
private:
|
||||
using underlying_tuple = cute::conditional_t<cute::is_tuple<Tuple>::value, Tuple, cute::tuple<Tuple>>;
|
||||
static constexpr size_t valid_index = cute::min(I, cute::tuple_size_v<underlying_tuple> - 1);
|
||||
|
||||
public:
|
||||
using type = cute::conditional_t<(I < cute::tuple_size_v<underlying_tuple>),
|
||||
cute::tuple_element_t<valid_index, underlying_tuple>,
|
||||
void>;
|
||||
};
|
||||
|
||||
template <size_t I, class Tuple>
|
||||
using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype<I, Tuple>::type;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
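A compile-time sketch (not part of this commit) of the trait above: it picks the operand (index 0), scale (index 1), or bias (index 2) type out of the "single type or tuple" element specification used by the mixed-input builders; the MixedB alias is only an example.

#include <type_traits>
#include "cutlass/numeric_types.h"
#include "cutlass/detail/collective.hpp"

using MixedB = cute::tuple<cutlass::int4b_t, cutlass::half_t>;   // quantized operand + scale
namespace mixed_detail = cutlass::gemm::collective::detail;

static_assert(std::is_same_v<mixed_detail::deduce_mixed_width_dtype_t<0, MixedB>, cutlass::int4b_t>);
static_assert(std::is_same_v<mixed_detail::deduce_mixed_width_dtype_t<1, MixedB>, cutlass::half_t>);
static_assert(std::is_same_v<mixed_detail::deduce_mixed_width_dtype_t<2, MixedB>, void>);  // no bias supplied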
|
||||
@ -187,6 +187,7 @@ constexpr bool is_tma_copy_engine() {
|
||||
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_IM2COL, GmemTiledCopy>
|
||||
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_IM2COL_MULTICAST, GmemTiledCopy>
|
||||
|| cute::is_base_of_v<cute::SM90_TMA_STORE, GmemTiledCopy>
|
||||
|| cute::is_base_of_v<cute::SM90_TMA_STORE_IM2COL, GmemTiledCopy>
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -104,19 +104,22 @@ sm90_compute_tile_shape_or_override() {
|
||||
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
|
||||
if constexpr (detail::sm90_is_cooperative_v<Schedule>) {
|
||||
using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{})));
|
||||
if constexpr (size<0>(TileShape_MNK{}) >= 128) {
|
||||
return Shape<_128,_32>{};
|
||||
return Shape<_128, N_tile>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_64,_32>{};
|
||||
return Shape<_64, N_tile>{};
|
||||
}
|
||||
}
|
||||
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
|
||||
if constexpr (sizeof_bits_v<ElementD> == 8) {
|
||||
return Shape<_64,_64>{};
|
||||
using N_tile = decltype(cute::min(_64{}, get<1>(TileShape_MNK{})));
|
||||
return Shape<_64, N_tile>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_64,_32>{};
|
||||
using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{})));
|
||||
return Shape<_64,N_tile>{};
|
||||
}
|
||||
}
|
||||
else {
|
||||
@ -265,6 +268,13 @@ struct Sm90TmaBuilderImpl {
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
using CopyOpS2G =
|
||||
SM90_TMA_STORE
|
||||
;
|
||||
using CopyOpG2S =
|
||||
SM90_TMA_LOAD
|
||||
;
|
||||
|
||||
// TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
|
||||
// instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
|
||||
using FusionCallbacks =
|
||||
@ -285,10 +295,10 @@ struct Sm90TmaBuilderImpl {
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
SM90_TMA_LOAD,
|
||||
CopyOpG2S,
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, EpilogueTile_MN>()),
|
||||
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
|
||||
SM90_TMA_STORE,
|
||||
CopyOpS2G,
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
|
||||
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>())
|
||||
>;
|
||||
@ -400,6 +410,7 @@ template <
|
||||
class ElementD,
|
||||
class GmemLayoutTagD,
|
||||
int AlignmentD,
|
||||
class Schedule,
|
||||
FloatRoundStyle RoundStyle
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
@ -416,9 +427,11 @@ struct CollectiveBuilder<
|
||||
ElementD,
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
NoSmemWarpSpecialized,
|
||||
fusion::LinearCombination<ElementD,ElementCompute,ElementCompute,RoundStyle>,
|
||||
void> {
|
||||
Schedule,
|
||||
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
|
||||
cute::enable_if_t<cute::is_same_v<Schedule, NoSmemWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, NoSmemWarpSpecializedArray> ||
|
||||
cute::is_same_v<Schedule, NoSmemWarpSpecializedGroup> >> {
|
||||
|
||||
// Passing void C disables source load
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
||||
@ -433,12 +446,21 @@ struct CollectiveBuilder<
|
||||
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
|
||||
ScaleType, RoundStyle, ElementC>;
|
||||
|
||||
using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
||||
cutlass::epilogue::collective::DefaultEpilogue<
|
||||
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
|
||||
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
|
||||
ThreadOp,
|
||||
cutlass::gemm::EpilogueDefault>
|
||||
using CollectiveOp = cute::conditional_t<
|
||||
cute::is_same_v<Schedule, NoSmemWarpSpecialized>,
|
||||
cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
||||
cutlass::epilogue::collective::DefaultEpilogue<
|
||||
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
|
||||
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
|
||||
ThreadOp,
|
||||
cutlass::gemm::EpilogueDefault>>,
|
||||
// Epilogue for Ptr-Array and Grouped Gemm
|
||||
cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
||||
cutlass::epilogue::collective::DefaultEpilogueArray<
|
||||
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
|
||||
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
|
||||
ThreadOp,
|
||||
Schedule>>
|
||||
>;
|
||||
};
|
||||
|
||||
@ -533,6 +555,9 @@ struct CollectiveBuilder<
|
||||
FusionOperation,
|
||||
void> {
|
||||
private:
|
||||
static_assert(cute::is_same_v<FusionOperation, fusion::LinearCombination<ElementD,ElementCompute,ElementC,ElementCompute>>,
|
||||
"Auto schedule doesn't support fusion. Use one of the TmaWarpSpecialized schedules instead.");
|
||||
|
||||
// Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance)
|
||||
// since TMA epilogues are not compatible with non-TMA non-WS mainloops
|
||||
using EpilogueSchedule = NoSmemWarpSpecialized;
|
||||
@ -595,7 +620,7 @@ CollectiveBuilder<
|
||||
cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
|
||||
private:
|
||||
using FusionOp =
|
||||
fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
|
||||
fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementC, ElementCompute, Schedule::Round>;
|
||||
using ImplSchedule =
|
||||
cute::conditional_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule>,
|
||||
TmaWarpSpecialized, TmaWarpSpecializedCooperative>;
|
||||
@ -677,7 +702,7 @@ private:
|
||||
GmemStrideTypeAux, typename Schedule::ElementT>());
|
||||
using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
|
||||
GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
|
||||
typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute
|
||||
typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute
|
||||
>;
|
||||
using FusionCallbacksAux = fusion::FusionCallbacks<
|
||||
DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux
|
||||
@ -685,7 +710,7 @@ private:
|
||||
|
||||
using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
|
||||
Schedule::template ActivationFunctor, ElementD, ElementCompute,
|
||||
typename Schedule::ElementBias, ElementCompute
|
||||
typename Schedule::ElementBias, ElementC_, ElementCompute
|
||||
>;
|
||||
using FusionCallbacksNoAux = fusion::FusionCallbacks<
|
||||
DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN
|
||||
@ -750,7 +775,7 @@ struct CollectiveBuilder<
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
cutlass::gemm::EpilogueTransposed,
|
||||
fusion::LinearCombination<ElementD,ElementCompute,ElementCompute,RoundStyle>,
|
||||
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
|
||||
void> {
|
||||
// Passing void C disables source load
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
||||
|
||||
@ -62,7 +62,7 @@ template <
|
||||
class GmemLayoutTagD,
|
||||
int AlignmentD,
|
||||
class Schedule,
|
||||
class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute>,
|
||||
class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute,ElementC,ElementCompute>,
|
||||
class Enable = void
|
||||
>
|
||||
struct CollectiveBuilder {
|
||||
|
||||
@ -54,6 +54,7 @@ class CollectiveEpilogue {
|
||||
|
||||
#include "detail.hpp"
|
||||
#include "default_epilogue.hpp"
|
||||
#include "default_epilogue_array.hpp"
|
||||
#include "epilogue_tensor_broadcast.hpp"
|
||||
#include "sm70_epilogue_vectorized.hpp"
|
||||
#include "sm90_epilogue_tma_warpspecialized.hpp"
|
||||
|
||||
@ -131,8 +131,9 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
// Note: SharedStorage is unused for DefaultEpilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
DefaultEpilogue(Params const& params_)
|
||||
DefaultEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage())
|
||||
: params(params_), epilogue_op(params_.thread) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
|
||||
254
include/cutlass/epilogue/collective/default_epilogue_array.hpp
Normal file
@ -0,0 +1,254 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing elementwise operations used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/numeric/int.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Applies an element wise operation to all elements within the fragment
|
||||
// and writes them out to destination storage.
|
||||
template <
|
||||
class StrideC_,
|
||||
class StrideD_,
|
||||
class ThreadEpilogueOp_,
|
||||
class EpilogueSchedule_
|
||||
>
|
||||
class DefaultEpilogueArray {
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using EpilogueSchedule = EpilogueSchedule_;
|
||||
|
||||
// derived types of output thread level operator
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
|
||||
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
||||
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
||||
using ElementScalar = ElementCompute;
|
||||
using ElementC = typename ThreadEpilogueOp::ElementC;
|
||||
using StrideC = StrideC_;
|
||||
using ElementD = typename ThreadEpilogueOp::ElementD;
|
||||
using StrideD = StrideD_;
|
||||
using StridesC = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
|
||||
StrideC const*, StrideC>;
|
||||
using StridesD = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
|
||||
StrideD const*, StrideD>;
|
||||
|
||||
using GmemTiledCopyC = void;
|
||||
using GmemTiledCopyD = void;
|
||||
|
||||
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
|
||||
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
||||
|
||||
static_assert(cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup> ||
|
||||
cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedArray>, "Incompatible epilogue schedule.");
|
||||
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
|
||||
struct SharedStorage { };
|
||||
|
||||
// Host side epilogue arguments
|
||||
struct Arguments {
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
ElementC const** ptr_C = nullptr;
|
||||
StridesC dC{};
|
||||
ElementD** ptr_D = nullptr;
|
||||
    StridesD dD{};
  };

  // Device side epilogue params
  using Params = Arguments;

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(
      ProblemShape const&,
      Arguments const& args,
      [[maybe_unused]] void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) {
    return cutlass::Status::kSuccess;
  }

  template<class ProblemShape>
  CUTLASS_HOST_DEVICE static bool
  can_implement(
      [[maybe_unused]] ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    return true;
  }

  CUTLASS_HOST_DEVICE
  DefaultEpilogueArray(Params const& params_)
      : params(params_), epilogue_op(params_.thread) { }

  CUTLASS_DEVICE
  bool
  is_source_needed() {
    return epilogue_op.is_source_needed();
  }

  template<
    class ProblemShapeMNKL,
    class BlockShapeMNK,
    class BlockCoordMNKL,
    class FrgEngine, class FrgLayout,
    class TiledMma,
    class ResidueMNK
  >
  CUTLASS_HOST_DEVICE void
  operator()(
      ProblemShapeMNKL problem_shape_mnkl,
      BlockShapeMNK blk_shape_MNK,
      BlockCoordMNKL blk_coord_mnkl,
      cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
      TiledMma tiled_mma,
      ResidueMNK residue_mnk,
      int thread_idx,
      [[maybe_unused]] char* smem_buf)
  {
    using namespace cute;
    using X = Underscore;

    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
    static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

    // Separate out problem shape for convenience
    auto M = get<0>(problem_shape_mnkl);
    auto N = get<1>(problem_shape_mnkl);
    auto L = get<3>(problem_shape_mnkl);
    // Batches are managed by using appropriate pointers to C and D matrices
    const int32_t mock_L = 1;
    const int32_t mock_l_coord = 0;
    // Slice to get the tile this CTA is responsible for
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;

    StrideC stride_c;
    StrideD stride_d;
    if constexpr (cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>) {
      stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
      stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD[l_coord]);
    }
    else {
      stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
      stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);
    }

    // Represent the full output tensor
    Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C[l_coord]), make_shape(M,N,mock_L), stride_c);   // (m,n,l)
    Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d);   // (m,n,l)
    Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});                // (BLK_M,BLK_N,m,n,l)
    Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});                // (BLK_M,BLK_N,m,n,l)

    Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord);                                                 // (BLK_M,BLK_N)
    Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord);                                                 // (BLK_M,BLK_N)

    // Partition source and destination tiles to match the accumulator partitioning
    auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
    Tensor tCgD = thr_mma.partition_C(gD);                                                                 // (VEC,THR_M,THR_N)
    Tensor tCgC = thr_mma.partition_C(gC);                                                                 // (VEC,THR_M,THR_N)

    static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
    CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
        "Source and destination must have the same number of elements.");
    CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
        "Accumulator count must equal the destination element count.");

    // Make an identity coordinate tensor for predicating our output MN tile
    auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
    Tensor tCcD = thr_mma.partition_C(cD);

    // source is needed
    if (epilogue_op.is_source_needed()) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
        }
      }
    }
    // source is not needed, avoid load
    else {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i));
        }
      }
    }
  }

private:
  Params params;
  ThreadEpilogueOp epilogue_op;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace collective
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
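Note: the epilogue above implements the pointer-array (batched/grouped) case. Each batch index l_coord selects its own ptr_C[l_coord]/ptr_D[l_coord] (and, for the group schedule, its own stride), after which the tile is processed as if L were 1 (mock_L / mock_l_coord). A minimal standalone sketch of that indexing idea, using hypothetical names rather than the actual CUTLASS types:

#include <cstdint>
#include <vector>

// Hypothetical container: one output pointer (and, for grouped GEMM, one leading dimension) per batch.
struct BatchedOutputs {
  std::vector<float*>       ptr_D;
  std::vector<std::int64_t> ld_D;
};

// Select the batch's pointer/stride up front; the rest of the epilogue then behaves
// like a single-batch (L = 1) problem.
inline float* select_batch(BatchedOutputs const& out, std::int32_t l_coord, std::int64_t& ld) {
  ld = out.ld_D[l_coord];
  return out.ptr_D[l_coord];
}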
@ -170,7 +170,8 @@ public:
[[maybe_unused]] TileCoordMNKL tile_coord_mnkl,
[[maybe_unused]] TiledMma tiled_mma,
[[maybe_unused]] int thread_idx,
[[maybe_unused]] TensorStorage& shared_tensors)
[[maybe_unused]] TensorStorage& shared_tensors,
[[maybe_unused]] int subtile_idx=-1)
{
return load_pipe_producer_state;
}
@ -202,7 +203,8 @@ public:
cute::Tensor<AccEngine,AccLayout> accumulators,
TiledMma tiled_mma,
int thread_idx,
TensorStorage& shared_tensors)
TensorStorage& shared_tensors,
int subtile_index = -1)
{
constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK);
auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {

@ -85,6 +85,7 @@ public:
using CopyAtomR2G = CopyAtomR2G_;

static const int kOutputAlignment = ThreadEpilogueOp::kCount;

using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
@ -179,7 +180,7 @@ public:

// synchronizing function for smem reads/writes
#if CUDA_BARRIER_ENABLED
auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); };
auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
#else
auto synchronize = [] () { __syncthreads(); };
#endif
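Note: the last hunk above replaces the hard-coded named-barrier id 0 with cutlass::arch::ReservedNamedBarriers::EpilogueBarrier. A rough sketch of the idea, with assumed (not actual) id values, is reserving dedicated slots among the hardware named barriers for library-internal synchronization so user code cannot collide with them:

// Assumed ids for illustration only; the real list lives in include/cutlass/arch/barrier.h.
enum class ReservedBarrierId : int {
  EpilogueBarrier  = 1,  // barrier slot owned by the epilogue
  FirstUserBarrier = 8   // ids below this are reserved for CUTLASS-internal use in this sketch
};

// Kernels then synchronize on barrier_id(ReservedBarrierId::EpilogueBarrier) instead of a literal 0.
constexpr int barrier_id(ReservedBarrierId id) { return static_cast<int>(id); }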
@ -109,8 +109,8 @@ public:
|
||||
using CopyOpR2S = CopyOpR2S_;
|
||||
|
||||
using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
|
||||
using GmemTiledCopyC = SM90_TMA_LOAD;
|
||||
using GmemTiledCopyD = SM90_TMA_STORE;
|
||||
using GmemTiledCopyC = CopyOpG2S;
|
||||
using GmemTiledCopyD = CopyOpS2G;
|
||||
|
||||
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
|
||||
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
|
||||
@ -198,12 +198,12 @@ public:
|
||||
struct Params {
|
||||
using TMA_C = decltype(make_tma_copy(
|
||||
CopyOpG2S{},
|
||||
make_tensor(static_cast<SmemElementC const*>(nullptr),
|
||||
make_tensor(make_gmem_ptr(static_cast<SmemElementC const*>(nullptr)),
|
||||
repeat_like(StrideC{}, int32_t(0)), StrideC{}),
|
||||
SmemLayoutC{}(_,_,0)));
|
||||
using TMA_D = decltype(make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
make_tensor(static_cast<ElementD const*>(nullptr),
|
||||
make_tensor(make_gmem_ptr(static_cast<ElementD const*>(nullptr)),
|
||||
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
|
||||
SmemLayoutD{}(_,_,0)));
|
||||
|
||||
@ -225,14 +225,20 @@ public:
|
||||
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
auto M_C =
|
||||
size(M)
|
||||
;
|
||||
auto M_D =
|
||||
size(M)
|
||||
;
|
||||
|
||||
typename Params::TMA_C tma_load_c;
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC));
|
||||
Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC));
|
||||
tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0));
|
||||
}
|
||||
|
||||
Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD));
|
||||
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
|
||||
typename Params::TMA_D tma_store_d = make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
tensor_d,
|
||||
@ -332,25 +338,31 @@ public:
|
||||
TileCoordMNKL tile_coord_mnkl,
|
||||
TiledMma tiled_mma,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors) {
|
||||
TensorStorage& shared_tensors,
|
||||
int subtile_idx=-1) {
|
||||
using namespace cute;
|
||||
|
||||
// Indexing variables
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
|
||||
|
||||
auto coord_shape =
|
||||
make_coord(m_coord, n_coord, l_coord)
|
||||
;
|
||||
|
||||
// Tile residue
|
||||
auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(tile_shape_MNK)>{}, [&](auto i) {
|
||||
auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(tile_shape_MNK)>{}, [&](auto i) {
|
||||
return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
|
||||
}));
|
||||
auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(tile_shape_MNK)>{}, [&](auto i) {
|
||||
auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(tile_shape_MNK)>{}, [&](auto i) {
|
||||
return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
|
||||
}));
|
||||
auto residue_mn = make_coord(m_max_coord, n_max_coord);
|
||||
|
||||
// Represent the full source tensor, slice to get the tile this CTA is currently responsible for
|
||||
Tensor mC = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
|
||||
Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N)
|
||||
Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
|
||||
Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{}));
|
||||
Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N)
|
||||
|
||||
// Apply epilogue subtile, get matching smem tensor
|
||||
SmemElementC* ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D.data());
|
||||
@ -391,6 +403,9 @@ public:
for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) {
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) {
if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gC_epi)) + epi_m) != subtile_idx) {
continue;
}
// Acquire the lock for this stage
constexpr uint16_t mcast_mask = 0;
uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
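Note: the new subtile_idx parameter restricts the producer load (and, symmetrically, the store below) to a single epilogue subtile; -1 preserves the old behavior of visiting every (epi_m, epi_n) subtile. A plain C++ sketch of the flattened-index filter used above:

#include <cstdio>

// Subtiles are visited in (epi_n, epi_m) order; a non-negative subtile_idx selects exactly one of
// them via the flattened index epi_n * epi_m_count + epi_m, mirroring the check in the hunk above.
void visit_subtiles(int epi_m_count, int epi_n_count, int subtile_idx = -1) {
  for (int epi_n = 0; epi_n < epi_n_count; ++epi_n) {
    for (int epi_m = 0; epi_m < epi_m_count; ++epi_m) {
      if (subtile_idx != -1 && (epi_n * epi_m_count + epi_m) != subtile_idx) {
        continue;  // skip everything except the requested subtile
      }
      std::printf("process subtile (epi_m=%d, epi_n=%d)\n", epi_m, epi_n);
    }
  }
}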
@ -449,31 +464,36 @@ public:
|
||||
cute::Tensor<AccEngine,AccLayout> accumulators,
|
||||
TiledMma tiled_mma,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors) {
|
||||
TensorStorage& shared_tensors,
|
||||
int subtile_idx=-1) {
|
||||
using namespace cute;
|
||||
using ElementAccumulator = typename AccEngine::value_type;
|
||||
using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::ElementCompute;
|
||||
using ElementCompute = cute::conditional_t<cute::is_void_v<ElementCompute_>,ElementAccumulator,ElementCompute_>;
|
||||
|
||||
static_assert(is_rmem<AccEngine>::value, "Accumulator must be RF resident.");
|
||||
static_assert(cute::rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
|
||||
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
||||
static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
|
||||
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
||||
static_assert(is_static<TileShapeMNK>::value, "TileShapeMNK must be static");
|
||||
static_assert(cute::rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
|
||||
static_assert(cute::rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
|
||||
static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
|
||||
static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
|
||||
|
||||
// Indexing variables
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
|
||||
auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{});
|
||||
auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{});
|
||||
auto mma_tile_m = tile_size<0>(tiled_mma);
|
||||
auto mma_tile_n = tile_size<1>(tiled_mma);
|
||||
auto epi_tile_m = size<0>(EpilogueTile{});
|
||||
auto epi_tile_n = size<1>(EpilogueTile{});
|
||||
|
||||
auto coord_shape =
|
||||
make_coord(m_coord, n_coord, l_coord)
|
||||
;
|
||||
|
||||
// Represent the full output tensor, slice to get the tile this CTA is responsible for
|
||||
Tensor mD = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
|
||||
Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N)
|
||||
Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L)
|
||||
Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{}));
|
||||
Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N)
|
||||
|
||||
// Apply epilogue subtiling
|
||||
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
@ -530,11 +550,11 @@ public:
Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)

// Coordinate tensors and residue for tile quantization
auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(CtaTileMNK{})>{}, [&](auto i) {
auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(CtaTileMNK{})>{}, [&](auto i) {
auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl);
return cute::max(0, c_m);
}));
auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(CtaTileMNK{})>{}, [&](auto i) {
auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(CtaTileMNK{})>{}, [&](auto i) {
auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl);
return cute::max(0, c_n);
}));
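Note: the residue computed above is, per mode, the problem extent minus the extent already covered by earlier tiles, clamped at zero. A scalar sketch:

#include <algorithm>

// Valid extent of this CTA's tile in one mode: whatever the problem has left after the tiles
// before it, clamped at zero. E.g. M = 100 with a 64-wide CTA tile gives tile 1 a residue of 36.
inline int residue(int problem_extent, int tile_extent, int tile_coord) {
  return std::max(0, problem_extent - tile_extent * tile_coord);
}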
@ -559,13 +579,13 @@ public:
|
||||
tRS_cD,
|
||||
tRS_rC
|
||||
};
|
||||
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
|
||||
auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks<RefSrc>(cst_args);
|
||||
bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed();
|
||||
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
|
||||
|
||||
// Thread synchronizer for previously issued waits or fences
|
||||
// to ensure visibility of smem reads/writes to threads or TMA unit
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 0); };
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
|
||||
// Predication for TMA store (one warp issues TMA store)
|
||||
bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0;
|
||||
@ -594,9 +614,12 @@ public:
|
||||
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) {
|
||||
bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1;
|
||||
|
||||
if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gD_epi)) + epi_m) != subtile_idx) {
|
||||
continue;
|
||||
}
|
||||
// The current tile in accumulator
|
||||
int mma_m = epi_m;
|
||||
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
|
||||
int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n;
|
||||
Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n);
|
||||
|
||||
if (is_producer_load_needed) {
|
||||
|
||||
@ -46,6 +46,8 @@ namespace cutlass::epilogue {
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct NoSmemWarpSpecialized {};
|
||||
struct NoSmemWarpSpecializedArray {};
|
||||
struct NoSmemWarpSpecializedGroup {};
|
||||
struct TmaWarpSpecialized {};
|
||||
struct TmaWarpSpecializedCooperative {};
|
||||
// DEPRECATED schedules, will be removed in next release
|
||||
|
||||
@ -50,6 +50,8 @@ struct FusionOperation {
|
||||
// metadata types/queries that can be overrided
|
||||
using ElementOutput = void;
|
||||
using ElementCompute = void;
|
||||
|
||||
using ElementSource = void;
|
||||
static constexpr bool IsSourceSupported = false;
|
||||
|
||||
using ElementScalar = void;
|
||||
@ -96,11 +98,13 @@ struct ScaledAcc : FusionOperation {
template<
class ElementOutput_,
class ElementCompute_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinearCombination
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
using ElementSource = ElementSource_;
static constexpr bool IsSourceSupported = true;
};

@ -109,11 +113,12 @@ template<
template <class> class ActivationFn_,
class ElementOutput_,
class ElementCompute_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinCombEltAct
: LinearCombination<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
using ActivationFn = ActivationFn_<ElementCompute_>;
static constexpr bool IsEltActSupported = true;
};
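Note: each fusion operation above gains an ElementSource_ parameter, defaulted to ElementOutput_ and inserted before ElementScalar_. Instantiations that relied on the defaults are unaffected, while callers that spelled out the scalar type positionally must be updated, which is what the remaining hunks in this file and in the FusionCallbacks specializations below do. A small sketch of the pattern with hypothetical names:

// Hypothetical names, not the CUTLASS types: the new, defaulted parameter sits before the scalar type.
template <
  class ElementOutput,
  class ElementCompute,
  class ElementSource = ElementOutput,   // new, defaulted
  class ElementScalar = ElementCompute
>
struct MyLinearCombination {
  using Source = ElementSource;
};

using DefaultC = MyLinearCombination<float, float>;          // unchanged spelling, Source = float
using MixedC   = MyLinearCombination<float, float, double>;  // C may now use a different element type than D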
@ -123,12 +128,13 @@ template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct LinCombPerRowBias
|
||||
: LinearCombination<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
|
||||
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
|
||||
using ElementBias = ElementBias_;
|
||||
static constexpr int AlignmentBias = AlignmentBias_;
|
||||
static constexpr bool IsPerRowBiasSupported = true;
|
||||
@ -140,13 +146,14 @@ template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct LinCombPerRowBiasEltAct
|
||||
: LinCombPerRowBias<ElementOutput_, ElementCompute_,
|
||||
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
using ActivationFn = ActivationFn_<ElementCompute_>;
|
||||
static constexpr bool IsEltActSupported = true;
|
||||
};
|
||||
@ -160,6 +167,7 @@ template<
|
||||
class ElementCompute_,
|
||||
class ElementAux_ = ElementOutput_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
@ -167,7 +175,7 @@ template<
|
||||
>
|
||||
struct LinCombPerRowBiasEltActAux
|
||||
: LinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
|
||||
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
using ElementAux = ElementAux_;
|
||||
using GmemLayoutTagAux = GmemLayoutTagAux_;
|
||||
static constexpr int AlignmentAux = AlignmentAux_;
|
||||
@ -180,6 +188,7 @@ template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_, // per-row alpha/beta
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
int AlignmentScalar_ = 128 / sizeof_bits_v<ElementScalar_>,
|
||||
@ -187,7 +196,7 @@ template<
|
||||
>
|
||||
struct PerRowLinCombPerRowBiasEltAct
|
||||
: LinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
|
||||
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
static constexpr int AlignmentScalar = AlignmentScalar_;
|
||||
static constexpr bool IsPerRowScaleSupported = true;
|
||||
};
|
||||
@ -202,13 +211,14 @@ template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct ScaledLinCombPerRowBiasEltAct
|
||||
: LinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
|
||||
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
static constexpr bool IsScaleFactorSupported = true;
|
||||
};
|
||||
|
||||
@ -231,6 +241,7 @@ template<
|
||||
class ElementAux_ = ElementOutput_,
|
||||
class ElementAmax_ = ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
@ -238,7 +249,7 @@ template<
|
||||
>
|
||||
struct ScaledLinCombPerRowBiasEltActAmaxAux
|
||||
: ScaledLinCombPerRowBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
|
||||
ElementBias_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> {
|
||||
using ElementAmax = ElementAmax_;
|
||||
static constexpr bool IsAbsMaxSupported = true;
|
||||
|
||||
@ -257,12 +268,16 @@ template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementAux_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct LinCombDeEltAct
|
||||
: LinCombEltAct<ActivationFn_, ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
|
||||
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
|
||||
using ActivationFn = ActivationFn_<ElementCompute_>;
|
||||
static constexpr bool IsDeEltActSupported = true;
|
||||
|
||||
using ElementAux = ElementAux_;
|
||||
using GmemLayoutTagAux = GmemLayoutTagAux_;
|
||||
static constexpr int AlignmentAux = AlignmentAux_;
|
||||
@ -280,6 +295,7 @@ template<
|
||||
class ElementCompute_,
|
||||
class ElementAux_ = ElementOutput_,
|
||||
class ElementBias_ = ElementCompute_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
@ -287,7 +303,7 @@ template<
|
||||
>
|
||||
struct LinCombDeEltActDePerRowBias
|
||||
: LinCombDeEltAct<GmemLayoutTagAux_, ActivationFn_, ElementOutput_, ElementCompute_,
|
||||
ElementAux_, ElementScalar_, AlignmentAux_, RoundStyle_> {
|
||||
ElementAux_, ElementSource_, ElementScalar_, AlignmentAux_, RoundStyle_> {
|
||||
using ElementBias = ElementBias_;
|
||||
static constexpr int AlignmentBias = AlignmentBias_;
|
||||
static constexpr bool IsDePerRowBiasSupported = true;
|
||||
|
||||
@ -113,13 +113,14 @@ struct FusionCallbacks<
template<
class ElementOutput,
class ElementCompute,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinearCombination =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
Sm90ScalarBroadcast<ElementScalar>, // alpha
Sm90AccFetch // acc
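Note: Sm90LinearCombination is an epilogue visitor tree whose root fuses beta * C + (alpha * acc); the change types the C fetch node by ElementSource. A scalar reference for what one element of that tree computes:

// What one element of the Sm90LinearCombination tree evaluates: the root homogeneous_multiply_add
// node combines the broadcast beta, the fetched C value, and the inner alpha * acc subtree.
template <class T>
T linear_combination_ref(T alpha, T acc, T beta, T c) {
  return beta * c + (alpha * acc);
}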
@ -133,6 +134,7 @@ template <
|
||||
bool ReuseSmemC,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
FloatRoundStyle RoundStyle,
|
||||
class CtaTileShapeMNK,
|
||||
@ -140,13 +142,13 @@ template <
|
||||
>
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::LinearCombination<ElementOutput, ElementCompute, ElementScalar, RoundStyle>,
|
||||
fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementScalar, RoundStyle> {
|
||||
> : Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> {
|
||||
|
||||
using Impl = Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementScalar, RoundStyle>;
|
||||
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementScalar, RoundStyle>;
|
||||
using Impl = Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
|
||||
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
|
||||
|
||||
struct Arguments {
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
@ -180,12 +182,13 @@ template<
|
||||
template <class> class ActivationFn,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90LinCombEltAct =
|
||||
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc))
|
||||
Sm90LinearCombination<ElementCompute, ElementCompute, ElementScalar, RoundStyle> // beta * C + (alpha * acc)
|
||||
Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle> // beta * C + (alpha * acc)
|
||||
>;
|
||||
|
||||
template <
|
||||
@ -196,6 +199,7 @@ template <
|
||||
template <class> class ActivationFn,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
FloatRoundStyle RoundStyle,
|
||||
class CtaTileShapeMNK,
|
||||
@ -203,13 +207,13 @@ template <
|
||||
>
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementScalar, RoundStyle>,
|
||||
fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementScalar, RoundStyle> {
|
||||
> : Sm90LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle> {
|
||||
|
||||
using Impl = Sm90LinCombEltAct<ActivationFn, typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementScalar, RoundStyle>;
|
||||
using Operation = fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementScalar, RoundStyle>;
|
||||
using Impl = Sm90LinCombEltAct<ActivationFn, typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
|
||||
using Operation = fusion::LinCombEltAct<ActivationFn, ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
|
||||
|
||||
struct Arguments {
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
@ -250,6 +254,7 @@ template<
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
@ -257,7 +262,7 @@ template<
|
||||
using Sm90LinCombPerRowBias =
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90SrcFetch<ElementSource>, // C
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
@ -273,6 +278,7 @@ template <
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
FloatRoundStyle RoundStyle,
|
||||
@ -281,15 +287,15 @@ template <
|
||||
>
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::LinCombPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>,
|
||||
fusion::LinCombPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90LinCombPerRowBias<
|
||||
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> {
|
||||
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> {
|
||||
using Impl = Sm90LinCombPerRowBias<
|
||||
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>;
|
||||
CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
|
||||
using Operation = fusion::LinCombPerRowBias<
|
||||
ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>;
|
||||
ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
|
||||
|
||||
struct Arguments {
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
@ -330,13 +336,14 @@ template<
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90LinCombPerRowBiasEltAct =
|
||||
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>,
|
||||
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
|
||||
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>
|
||||
>;
|
||||
|
||||
template <
|
||||
@ -348,6 +355,7 @@ template <
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
FloatRoundStyle RoundStyle,
|
||||
@ -357,21 +365,21 @@ template <
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::LinCombPerRowBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90LinCombPerRowBiasEltAct<
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90LinCombPerRowBiasEltAct<
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::LinCombPerRowBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
@ -426,6 +434,7 @@ template<
|
||||
class ElementCompute,
|
||||
class ElementAux = ElementOutput,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
@ -434,7 +443,7 @@ template<
|
||||
using Sm90LinCombPerRowBiasEltActAux =
|
||||
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>,
|
||||
Sm90EVT<Sm90AuxStore<Stages, EpilogueTile, ElementAux, RoundStyle, StrideAux, SmemLayoutAtom, CopyOpR2S, AlignmentAux>,
|
||||
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
|
||||
Sm90LinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>
|
||||
>
|
||||
>;
|
||||
|
||||
@ -449,6 +458,7 @@ template <
|
||||
class ElementCompute,
|
||||
class ElementAux,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentAux,
|
||||
int AlignmentBias,
|
||||
@ -462,7 +472,7 @@ struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::LinCombPerRowBiasEltActAux<
|
||||
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
|
||||
ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
@ -470,18 +480,18 @@ struct FusionCallbacks<
|
||||
CopyOpR2S
|
||||
> : Sm90LinCombPerRowBiasEltActAux<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpR2S, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90LinCombPerRowBiasEltActAux<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpR2S, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::LinCombPerRowBiasEltActAux<
|
||||
GmemLayoutTagAux, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
@ -535,6 +545,7 @@ template<
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
int AlignmentScalar = 128 / sizeof_bits_v<ElementScalar>,
|
||||
@ -543,7 +554,7 @@ template<
|
||||
using Sm90PerRowLinCombPerRowBias =
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90SrcFetch<ElementSource>, // C
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
@ -558,6 +569,7 @@ template<
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
int AlignmentScalar = 128 / sizeof_bits_v<ElementScalar>,
|
||||
@ -566,7 +578,7 @@ template<
|
||||
using Sm90PerRowLinCombPerRowBiasEltAct =
|
||||
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>,
|
||||
Sm90PerRowLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute,
|
||||
ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle>
|
||||
ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle>
|
||||
>;
|
||||
|
||||
template <
|
||||
@ -578,6 +590,7 @@ template <
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
int AlignmentScalar,
|
||||
@ -588,21 +601,21 @@ template <
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::PerRowLinCombPerRowBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90PerRowLinCombPerRowBiasEltAct<
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90PerRowLinCombPerRowBiasEltAct<
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::PerRowLinCombPerRowBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
@ -664,6 +677,7 @@ template<
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
@ -671,7 +685,7 @@ template<
using Sm90ScaledLinCombPerRowBias =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90ScalarBroadcast<ElementScalar, Stride<_0,_0,_0>, 2>, // scale_c * beta
Sm90SrcFetch, // C
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ScalarBroadcast<ElementScalar, Stride<_0,_0,_0>, 3>, // scale_a * scale_b * alpha
Sm90AccFetch, // acc
@ -690,6 +704,7 @@ template<
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
@ -698,7 +713,7 @@ using Sm90ScaledLinCombPerRowBiasEltAct =
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>
>,
Sm90ScalarBroadcast<ElementScalar> // scale_d
>;
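Note: per the comments above, the scaled variant computes Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias, then D = activation(Z) * scale_d. A scalar reference sketch (assuming a ReLU-style activation for illustration):

#include <algorithm>

// Scalar reference for one output element; scale_a/scale_b/scale_c/scale_d are the tensor-wide
// scale factors and bias_row is the per-row bias broadcast along N.
float scaled_bias_act_ref(float acc, float c, float bias_row,
                          float alpha, float beta,
                          float scale_a, float scale_b, float scale_c, float scale_d) {
  float z = (scale_a * scale_b * alpha) * acc + (scale_c * beta) * c + bias_row;
  float activated = std::max(z, 0.0f);  // stand-in for ActivationFn
  return activated * scale_d;
}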
@ -712,6 +727,7 @@ template <
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
FloatRoundStyle RoundStyle,
|
||||
@ -721,21 +737,21 @@ template <
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::ScaledLinCombPerRowBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90ScaledLinCombPerRowBiasEltAct<
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90ScaledLinCombPerRowBiasEltAct<
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::ScaledLinCombPerRowBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
@ -819,6 +835,7 @@ template<
|
||||
class ElementAux = ElementOutput,
|
||||
class ElementAmax = ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
@ -827,7 +844,7 @@ template<
|
||||
using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
|
||||
Sm90SplitTreeVisitor<
|
||||
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
|
||||
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>,
|
||||
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
|
||||
// D = activation(Z) * scale_d, amax_d = max(abs(elements in D))
|
||||
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
|
||||
Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_d
|
||||
@ -860,6 +877,7 @@ template <
|
||||
class ElementAux,
|
||||
class ElementAmax,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentAux,
|
||||
int AlignmentBias,
|
||||
@ -873,7 +891,7 @@ struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
|
||||
ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
@ -882,19 +900,19 @@ struct FusionCallbacks<
|
||||
> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>,
|
||||
SmemLayoutAtom, CopyOpR2S, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>,
|
||||
SmemLayoutAtom, CopyOpR2S, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
|
||||
ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
@ -1014,13 +1032,14 @@ template<
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementAux = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90LinCombDeEltAct =
|
||||
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
|
||||
Sm90LinearCombination<ElementCompute, ElementCompute, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
|
||||
Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
|
||||
Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux> // aux
|
||||
>;
|
||||
|
||||
@ -1034,6 +1053,7 @@ template <
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementAux,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentAux,
|
||||
FloatRoundStyle RoundStyle,
|
||||
@ -1046,7 +1066,7 @@ struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::LinCombDeEltAct<
|
||||
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
|
||||
ElementAux, ElementScalar, AlignmentAux, RoundStyle
|
||||
ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
@ -1054,18 +1074,18 @@ struct FusionCallbacks<
|
||||
CopyOpS2R
|
||||
> : Sm90LinCombDeEltAct<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90LinCombDeEltAct<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::LinCombDeEltAct<
|
||||
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
|
||||
ElementAux, ElementScalar, AlignmentAux, RoundStyle
|
||||
ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
@ -1118,6 +1138,7 @@ template<
|
||||
class ElementCompute,
|
||||
class ElementAux = ElementOutput,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentAux = 128 / sizeof_bits_v<ElementAux>,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
@ -1128,7 +1149,7 @@ using Sm90LinCombDeEltActDePerRowBias =
|
||||
Sm90EVT<Sm90ColReduction<plus, plus, 0, CtaTileShapeMNK,
|
||||
ElementBias, ElementCompute, RoundStyle, Stride<_1,_0,int>, AlignmentBias>,
|
||||
Sm90LinCombDeEltAct<CtaTileShapeMNK, EpilogueTile, Stages, StrideAux, SmemLayoutAtom, CopyOpS2R, ActivationFn,
|
||||
ElementCompute, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle>
|
||||
ElementCompute, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle>
|
||||
>
|
||||
>;
|
||||
|
||||
@ -1143,6 +1164,7 @@ template <
|
||||
class ElementCompute,
|
||||
class ElementAux,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentAux,
|
||||
int AlignmentBias,
|
||||
@ -1156,7 +1178,7 @@ struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
||||
fusion::LinCombDeEltActDePerRowBias<
|
||||
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
|
||||
ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
@ -1164,18 +1186,18 @@ struct FusionCallbacks<
|
||||
CopyOpS2R
|
||||
> : Sm90LinCombDeEltActDePerRowBias<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90LinCombDeEltActDePerRowBias<
|
||||
CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, SmemLayoutAtom, CopyOpS2R, ActivationFn,
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::LinCombDeEltActDePerRowBias<
|
||||
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
|
||||
ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
|
||||
@ -215,16 +215,17 @@ template <
|
||||
class StrideScalar,
|
||||
int ScalarCount,
|
||||
template <class> class ScalarReduceFn,
|
||||
class ElementSource,
|
||||
class InputAddOp // Z
|
||||
>
|
||||
struct Sm90TreeVisitor<
|
||||
Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>,
|
||||
Sm90ScalarBroadcast<ElementScalar, StrideScalar, ScalarCount, ScalarReduceFn>,
|
||||
Sm90SrcFetch,
|
||||
Sm90SrcFetch<ElementSource>,
|
||||
InputAddOp
|
||||
> : Sm90VisitorImpl<
|
||||
Sm90ScalarBroadcast<ElementScalar, StrideScalar, ScalarCount, ScalarReduceFn>,
|
||||
Sm90SrcFetch,
|
||||
Sm90SrcFetch<ElementSource>,
|
||||
InputAddOp,
|
||||
Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>
|
||||
>
|
||||
@ -232,11 +233,10 @@ struct Sm90TreeVisitor<
|
||||
using Impl =
|
||||
Sm90VisitorImpl<
|
||||
Sm90ScalarBroadcast<ElementScalar, StrideScalar, ScalarCount, ScalarReduceFn>,
|
||||
Sm90SrcFetch,
|
||||
Sm90SrcFetch<ElementSource>,
|
||||
InputAddOp,
|
||||
Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>
|
||||
>;
|
||||
|
||||
using Params = typename Impl::Params;
|
||||
using SharedStorage = typename Impl::SharedStorage;
|
||||
|
||||
@ -260,8 +260,9 @@ struct Sm90TreeVisitor<
CUTLASS_DEVICE bool
is_C_load_needed() const {
auto const& bcast_op = get<0>(Impl::ops);
auto const& src_op = get<1>(Impl::ops);
auto const& added_op = get<2>(Impl::ops);
return bcast_op.scalar != 0 || added_op.is_C_load_needed();
return (bcast_op.scalar != 0 && src_op.is_C_load_needed()) || added_op.is_C_load_needed();
}
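Note: the updated predicate skips the C load not only when beta is zero but also when the source-fetch node itself reports that C is not read; any other node that needs C still forces the load. A one-line sketch of the logic:

// C must be loaded only if the beta * C term can contribute (beta != 0 and the source node reads C)
// or if some other node in the visitor tree consumes C.
inline bool is_c_load_needed(float beta, bool src_reads_c, bool added_op_needs_c) {
  return (beta != 0.0f && src_reads_c) || added_op_needs_c;
}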
|
||||
template <class CallbacksImpl>
|
||||
@ -320,17 +321,9 @@ struct Sm90TreeVisitor<
|
||||
// ReLU with aux bit tensor dReLU/dZ
|
||||
// Aux(i) = Z(i) >= 0 ? 1 : 0
|
||||
namespace detail {
|
||||
template <
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
FloatRoundStyle RoundStyle,
|
||||
class StrideMNL,
|
||||
int Alignment,
|
||||
bool EnableNullptr
|
||||
>
|
||||
struct Sm90ReLUAuxStore {
|
||||
static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet");
|
||||
|
||||
// Placeholder node so we can retain standard EVT structure
|
||||
template <class StrideMNL>
|
||||
struct Sm90ReLUAuxStore : Sm90VisitorImpl<> {
|
||||
struct SharedStorage {};
|
||||
|
||||
struct Arguments {
|
||||
@ -362,41 +355,90 @@ struct Sm90ReLUAuxStore {
|
||||
Sm90ReLUAuxStore() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { }
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
Params const params;
|
||||
// Specialization on the generic compute+aux EVT
|
||||
template <
|
||||
// Compute node
|
||||
template <class> class Activation,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
FloatRoundStyle RoundStyle,
|
||||
// Aux node
|
||||
int Stages,
|
||||
class EpilogueTile,
|
||||
class StrideMNL,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S,
|
||||
int Alignment,
|
||||
bool EnableNullptr,
|
||||
// Input node
|
||||
class InputOp
|
||||
>
|
||||
struct Sm90TreeVisitor<
|
||||
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle,
|
||||
enable_if_t<is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::ReLu<ElementCompute>> ||
|
||||
is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::Clamp<ElementCompute>> >>,
|
||||
Sm90TreeVisitor<
|
||||
Sm90AuxStore<
|
||||
Stages,
|
||||
EpilogueTile,
|
||||
cutlass::uint1b_t,
|
||||
RoundStyle,
|
||||
StrideMNL,
|
||||
SmemLayoutAtom,
|
||||
CopyOpR2S,
|
||||
Alignment,
|
||||
EnableNullptr
|
||||
>,
|
||||
InputOp
|
||||
>
|
||||
> : Sm90VisitorImpl<
|
||||
Sm90VisitorImpl<
|
||||
InputOp,
|
||||
detail::Sm90ReLUAuxStore<StrideMNL>
|
||||
>,
|
||||
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
|
||||
>
|
||||
{
|
||||
using Impl =
|
||||
Sm90VisitorImpl<
|
||||
Sm90VisitorImpl<
|
||||
InputOp,
|
||||
detail::Sm90ReLUAuxStore<StrideMNL>
|
||||
>,
|
||||
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
|
||||
>;
|
||||
using Params = typename Impl::Params;
|
||||
using SharedStorage = typename Impl::SharedStorage;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90TreeVisitor() {}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage)
|
||||
: params(params_), Impl(params_, shared_storage) {}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
Params const& params;
|
||||
|
||||
template <class RTensor, class GTensor, class CTensor, class ResidueMN>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
template <class RTensor, class GTensor, class CTensor, class ResidueMN, class CallbacksImpl>
|
||||
struct ConsumerStoreCallbacks : CallbacksImpl {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
RTensor&& tC_rAux,
|
||||
GTensor&& tC_gAux,
|
||||
CTensor tC_cAux,
|
||||
ResidueMN residue_mn,
|
||||
Params const& params)
|
||||
Params const& params,
|
||||
CallbacksImpl&& impl)
|
||||
: tC_rAux(cute::forward<RTensor>(tC_rAux)),
|
||||
tC_gAux(cute::forward<GTensor>(tC_gAux)),
|
||||
tC_cAux(tC_cAux),
|
||||
residue_mn(residue_mn),
|
||||
params(params) {}
|
||||
params(params),
|
||||
CallbacksImpl(cute::forward<CallbacksImpl>(impl)) {}
|
||||
|
||||
RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
@ -404,13 +446,23 @@ struct Sm90ReLUAuxStore {
ResidueMN residue_mn;
Params const& params;

template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
CUTLASS_DEVICE auto
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
Array<ElementInput, FragmentSize> const& frg_input) {
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<ElementOutput, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
// Unpack callbacks + params
auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple;
auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple;
auto const& [params_input_aux, params_compute] = params;
auto const& [params_input, params_aux] = params_input_aux;

// Visit the input node
Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n);

// Compute activation + aux
using ElementInput = typename decltype(frg_input)::Element;
using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput, FragmentSize, RoundStyle>;
using ConvertAux = PackPredicates<FragmentSize>;
using ComputeOutput = cutlass::epilogue::thread::ReLu<ElementCompute>;
using ComputeOutput = Activation<ElementCompute>;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize, RoundStyle>;
ConvertInput convert_input{};
ComputeOutput relu{};
@ -422,7 +474,12 @@ struct Sm90ReLUAuxStore {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
ElementCompute pre_relu = frg_compute[i];
frg_compute[i] = relu(frg_compute[i]);
if constexpr (is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::Clamp<ElementCompute>>) {
frg_compute[i] = relu(frg_compute[i], params_compute);
}
else {
frg_compute[i] = relu(frg_compute[i]);
}
frg_aux[i] = frg_compute[i] == pre_relu;
}

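The loop above applies the activation and records, per element, whether the value passed through unchanged (frg_aux[i]); those booleans are then converted to single-bit storage. A minimal sketch of that bit-packing step, assuming plain C++ and a hypothetical helper rather than CUTLASS's PackPredicates:

#include <cstddef>
#include <cstdint>

// Pack n boolean predicates into a little-endian bitfield, one bit per predicate.
inline void pack_predicates(bool const* preds, std::size_t n, std::uint8_t* out) {
  for (std::size_t i = 0; i < n; ++i) {
    if (i % 8 == 0) {
      out[i / 8] = 0;  // start a fresh byte
    }
    out[i / 8] |= static_cast<std::uint8_t>((preds[i] ? 1u : 0u) << (i % 8));
  }
}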
@ -435,8 +492,18 @@ struct Sm90ReLUAuxStore {
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
end() {
|
||||
// Unpack callbacks + params
|
||||
auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple;
|
||||
auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple;
|
||||
auto const& [params_input_aux, params_compute] = params;
|
||||
auto const& [params_input, params_aux] = params_input_aux;
|
||||
|
||||
// Visit the input node
|
||||
callbacks_input.end();
|
||||
|
||||
// Nullptr is no-op
|
||||
if constexpr (EnableNullptr) {
|
||||
if (params.ptr_aux == nullptr) {
|
||||
if (params_aux.ptr_aux == nullptr) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -473,114 +540,25 @@ struct Sm90ReLUAuxStore {
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
// Unpack params
|
||||
auto const& [params_input_aux, params_compute] = params;
|
||||
auto const& [params_input, params_aux] = params_input_aux;
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator<cutlass::uint1b_t>(params.ptr_aux));
|
||||
Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L)
|
||||
gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator<cutlass::uint1b_t>(params_aux.ptr_aux));
|
||||
Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L)
|
||||
Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
|
||||
|
||||
Tensor tC_gAux = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
gAux, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tC_rAux = make_tensor<cutlass::uint1b_t>(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.tCcD), decltype(args.residue_mn)>(
|
||||
cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
|
||||
auto callbacks_impl = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
|
||||
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.tCcD), decltype(args.residue_mn), decltype(callbacks_impl)>(
|
||||
cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params, cute::move(callbacks_impl));
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
// Specialization on the generic compute+aux EVT
|
||||
template <
|
||||
// Compute node
|
||||
template <class> class Activation,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
FloatRoundStyle RoundStyle,
|
||||
// Aux node
|
||||
int Stages,
|
||||
class EpilogueTile,
|
||||
class StrideMNL,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S,
|
||||
int Alignment,
|
||||
bool EnableNullptr,
|
||||
// Input node
|
||||
class InputOp
|
||||
>
|
||||
struct Sm90TreeVisitor<
|
||||
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle,
|
||||
enable_if_t<is_same_v<Activation<ElementCompute>, cutlass::epilogue::thread::ReLu<ElementCompute>>, void>>,
|
||||
Sm90TreeVisitor<
|
||||
Sm90AuxStore<
|
||||
Stages,
|
||||
EpilogueTile,
|
||||
cutlass::uint1b_t,
|
||||
RoundStyle,
|
||||
StrideMNL,
|
||||
SmemLayoutAtom,
|
||||
CopyOpR2S,
|
||||
Alignment,
|
||||
EnableNullptr
|
||||
>,
|
||||
InputOp
|
||||
>
|
||||
> : Sm90VisitorImpl<
|
||||
Sm90VisitorImpl<
|
||||
InputOp,
|
||||
detail::Sm90ReLUAuxStore<ElementOutput, ElementCompute, RoundStyle, StrideMNL, Alignment, EnableNullptr>
|
||||
>,
|
||||
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
|
||||
>
|
||||
{
|
||||
using Impl =
|
||||
Sm90VisitorImpl<
|
||||
Sm90VisitorImpl<
|
||||
InputOp,
|
||||
detail::Sm90ReLUAuxStore<ElementOutput, ElementCompute, RoundStyle, StrideMNL, Alignment, EnableNullptr>
|
||||
>,
|
||||
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
|
||||
>;
|
||||
|
||||
using Params = typename Impl::Params;
|
||||
using SharedStorage = typename Impl::SharedStorage;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90TreeVisitor() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90TreeVisitor(
|
||||
Params const& params,
|
||||
SharedStorage const& shared_storage)
|
||||
: Impl(params, shared_storage) {}
|
||||
|
||||
template <class CallbacksImpl>
|
||||
struct ConsumerStoreCallbacks : CallbacksImpl {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(CallbacksImpl&& impl)
|
||||
: CallbacksImpl(cute::forward<CallbacksImpl>(impl)) { }
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<ElementOutput, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
auto& [callbacks_input, callbacks_relu_aux] = get<0>(CallbacksImpl::callbacks_tuple).callbacks_tuple;
|
||||
|
||||
Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n);
|
||||
return callbacks_relu_aux.visit(frg_acc, epi_v, epi_m, epi_n, frg_input);
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
|
||||
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// Aux load for uint1b_t
|
||||
template <
|
||||
|
||||
@ -85,16 +85,17 @@ using Sm90SplitTreeFetch = Sm90AccFetch;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// returns C
|
||||
template <class Element>
|
||||
struct Sm90SrcFetch : Sm90VisitorImpl<> {
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return true;
|
||||
return is_C_load_needed();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return true;
|
||||
return not is_void_v<Element>;
|
||||
}
|
||||
|
||||
using Sm90VisitorImpl<>::Sm90VisitorImpl;
|
||||
@ -105,7 +106,6 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> {
|
||||
ConsumerStoreCallbacks(SrcTensor const& tCrC)
|
||||
: tCrC(tCrC) {}
|
||||
|
||||
// make this a pointer if we need default ctor for generic tuple of visitors
|
||||
SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N)
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
@ -122,7 +122,7 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> {
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
// register type may differ from logical type so we can't assert matching types here
|
||||
return ConsumerStoreCallbacks(args.tCrC);
|
||||
}
|
||||
};
|
||||
@ -344,7 +344,6 @@ struct Sm90AuxLoad {
|
||||
make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE)
|
||||
auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE)
|
||||
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_s2r), decltype(tSR_sAux)>(
|
||||
cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
|
||||
}
|
||||
|
||||
@ -303,7 +303,7 @@ private:
|
||||
static constexpr bool IsAtomic = is_atomic<GmemReduceFn<ElementCompute>>::value;
|
||||
static_assert(IsAtomic, "non-atomic scalar reduction not supported yet");
|
||||
|
||||
public:
|
||||
public:
|
||||
struct SharedStorage { };
|
||||
|
||||
struct Arguments {
|
||||
@ -814,7 +814,7 @@ public:
|
||||
}
|
||||
|
||||
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n);
|
||||
@ -843,8 +843,8 @@ public:
|
||||
return;
|
||||
}
|
||||
|
||||
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
|
||||
auto [m, n, k, l] = tile_coord_mnkl;
|
||||
constexpr bool ReferenceSrc = decltype(ref_src)::value;
|
||||
@ -1002,15 +1002,15 @@ public:
|
||||
return;
|
||||
}
|
||||
|
||||
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
|
||||
|
||||
using ReduceOutput = GmemReduceFn<ElementCompute>;
|
||||
using ConvertOutput = NumericConverter<ElementOutput, ElementCompute, RoundStyle>;
|
||||
ReduceOutput reduce_output{};
|
||||
ConvertOutput convert_output{};
|
||||
|
||||
|
||||
// Reduction over batches
|
||||
if (size<2>(stride(gCol_l)) == 0) {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
@ -1051,8 +1051,8 @@ public:
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_reduction_buffer_needed(int epi_m, int epi_n, bool is_last_iteration) const {
|
||||
auto const& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
auto const& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple;
|
||||
|
||||
return (not IsAtomic && // atomic reduction doesn't use smem
|
||||
@ -1111,7 +1111,7 @@ public:
|
||||
|
||||
auto args_tuple = make_tuple(
|
||||
bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
|
||||
args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple), params);
|
||||
}
|
||||
|
||||
@ -563,7 +563,6 @@ struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {
|
||||
template get_consumer_store_callbacks<ReferenceSrc>(args);
|
||||
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -614,7 +613,6 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputT
|
||||
template get_consumer_store_callbacks<ReferenceSrc>(args);
|
||||
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -692,7 +690,6 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
|
||||
template get_consumer_store_callbacks<ReferenceSrc>(args);
|
||||
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -49,7 +49,6 @@ namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
@ -66,13 +65,65 @@ struct ArrayMaximum {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ElementsPerAccess; ++i) {
|
||||
result[i] = fmax(lhs[i], rhs[i]);
|
||||
result[i] = platform::max(lhs[i].get(), rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<Element, ElementsPerAccess> operator()(
|
||||
Array<Element, ElementsPerAccess> const &lhs,
|
||||
Element rhs) const {
|
||||
|
||||
Array<Element, ElementsPerAccess> result;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ElementsPerAccess; ++i) {
|
||||
result[i] = platform::max(lhs[i].get(), rhs);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Partial specialization: Element=float
|
||||
template <int ElementsPerAccess>
|
||||
struct ArrayMaximum<float, ElementsPerAccess> {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<float, ElementsPerAccess> operator()(
|
||||
Array<float, ElementsPerAccess> const &lhs,
|
||||
Array<float, ElementsPerAccess> const &rhs) const {
|
||||
|
||||
Array<float, ElementsPerAccess> result;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ElementsPerAccess; ++i) {
|
||||
result[i] = fmax(lhs[i], rhs[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<float, ElementsPerAccess> operator()(
|
||||
Array<float, ElementsPerAccess> const &lhs,
|
||||
float rhs) const {
|
||||
|
||||
Array<float, ElementsPerAccess> result;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ElementsPerAccess; ++i) {
|
||||
result[i] = fmax(lhs[i], rhs);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization: Element=half
|
||||
template <int ElementsPerAccess>
|
||||
struct ArrayMaximum<half_t, ElementsPerAccess> {
|
||||
|
||||
@ -96,6 +147,8 @@ struct ArrayMaximum<half_t, ElementsPerAccess> {
|
||||
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
|
||||
}
|
||||
|
||||
static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
|
||||
|
||||
#else
|
||||
__half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
|
||||
__half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data());
|
||||
@ -133,6 +186,8 @@ struct ArrayMaximum<half_t, ElementsPerAccess> {
|
||||
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
|
||||
}
|
||||
|
||||
static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
|
||||
|
||||
#else
|
||||
|
||||
__half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
|
||||
@ -150,6 +205,90 @@ struct ArrayMaximum<half_t, ElementsPerAccess> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization: Element=bfloat16_t
|
||||
template <int ElementsPerAccess>
|
||||
struct ArrayMaximum<bfloat16_t, ElementsPerAccess> {
|
||||
|
||||
using NvType = __nv_bfloat16;
|
||||
using NvTypeV2 = __nv_bfloat162;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Array<bfloat16_t, ElementsPerAccess> operator()(
|
||||
Array<bfloat16_t, ElementsPerAccess> const &lhs,
|
||||
Array<bfloat16_t, ElementsPerAccess> const &rhs) const {
|
||||
|
||||
Array<bfloat16_t, ElementsPerAccess> result;
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
int const kVectorCount = ElementsPerAccess / 2;
|
||||
|
||||
|
||||
NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
|
||||
NvTypeV2 const *rhs_ptr = reinterpret_cast<NvTypeV2 const *>(rhs.raw_data());
|
||||
NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kVectorCount; ++i) {
|
||||
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
|
||||
}
|
||||
|
||||
#else
|
||||
NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
|
||||
NvType const *rhs_ptr = reinterpret_cast<NvType const *>(rhs.raw_data());
|
||||
NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ElementsPerAccess; ++i) {
|
||||
res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Array<bfloat16_t, ElementsPerAccess> operator()(
|
||||
Array<bfloat16_t, ElementsPerAccess> const &lhs,
|
||||
bfloat16_t rhs) const {
|
||||
|
||||
Array<bfloat16_t, ElementsPerAccess> result;
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
int const kVectorCount = ElementsPerAccess / 2;
|
||||
|
||||
|
||||
NvType rhs_raw = reinterpret_cast<NvType const &>(rhs);
|
||||
NvTypeV2 rhs_pair = __bfloat162bfloat162(rhs_raw);
|
||||
|
||||
NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
|
||||
NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kVectorCount; ++i) {
|
||||
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
|
||||
}
|
||||
|
||||
static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
|
||||
|
||||
#else
|
||||
|
||||
NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
|
||||
NvType const rhs_raw = reinterpret_cast<NvType const &>(rhs);
|
||||
NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ElementsPerAccess; ++i) {
|
||||
res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element, int ElementsPerAccess>
|
||||
@ -187,6 +326,25 @@ struct ReluConditional<half_t, ElementsPerAccess> {
|
||||
}
|
||||
};
|
||||
|
||||
template <int ElementsPerAccess>
|
||||
struct ReluConditional<bfloat16_t, ElementsPerAccess> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
bool conditional[],
|
||||
Array<bfloat16_t, ElementsPerAccess> const &fragment,
|
||||
bfloat16_t threshold) const {
|
||||
|
||||
__nv_bfloat16 y = reinterpret_cast<__nv_bfloat16 const &>(threshold);
|
||||
__nv_bfloat16 const *x = reinterpret_cast<__nv_bfloat16 const *>(fragment.raw_data());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ElementsPerAccess; ++i) {
|
||||
conditional[i] = !__hlt(x[i], y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -92,9 +92,9 @@ public:
|
||||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
@ -87,9 +87,9 @@ public:
|
||||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
@ -1698,13 +1698,13 @@ private:
|
||||
//
|
||||
|
||||
if (OutputOp::kStoreZ) {
|
||||
destination_iterator += reduce_fragment_idx;
|
||||
destination_iterator.store(frag_Z);
|
||||
++destination_iterator;
|
||||
}
|
||||
|
||||
if (OutputOp::kStoreT) {
|
||||
tensor_iterator += reduce_fragment_idx;
|
||||
tensor_iterator.store(frag_T);
|
||||
++tensor_iterator;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -187,7 +187,7 @@ CUTLASS_CONSTEXPR_IF_CXX17
value_t lcm(value_t a, value_t b) {
value_t temp = gcd(a, b);

return temp ? (cutlass::abs_for_integer(a) / temp * cutlass::abs_for_integer(b)) : 0;
return temp ? (cutlass::abs_for_integer(a) / temp * cutlass::abs_for_integer(b)) : value_t();
}

/**
@ -207,8 +207,11 @@ template <typename value_t>
CUTLASS_HOST_DEVICE
constexpr
value_t lcm_cxx11(value_t a, value_t b) {
return gcd_cxx11(a, b) ? (cutlass::abs_for_integer(a) / gcd_cxx11(a, b) * cutlass::abs_for_integer(b)) : 0;
return gcd_cxx11(a, b) ? (cutlass::abs_for_integer(a) / gcd_cxx11(a, b) *
cutlass::abs_for_integer(b))
: value_t();
}

/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
CUTLASS_CONSTEXPR_IF_CXX17

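Both lcm overloads above now fall back to a value-initialized value_t() rather than the literal 0 when the gcd is zero, keeping the result in the function's own type. A small self-contained sketch of the same gcd/lcm pattern, assuming plain C++ (hypothetical my_gcd/my_lcm, not cutlass::gcd/lcm):

#include <cstdint>

// lcm(a, b) = |a| / gcd(a, b) * |b|; dividing first keeps intermediates small.
constexpr std::int64_t my_gcd(std::int64_t a, std::int64_t b) {
  return b == 0 ? (a < 0 ? -a : a) : my_gcd(b, a % b);
}

constexpr std::int64_t my_lcm(std::int64_t a, std::int64_t b) {
  std::int64_t const g = my_gcd(a, b);
  std::int64_t const abs_a = a < 0 ? -a : a;
  std::int64_t const abs_b = b < 0 ? -b : b;
  return g ? (abs_a / g * abs_b) : std::int64_t();  // value-initialized fallback
}

static_assert(my_lcm(12, 18) == 36, "lcm(12, 18) is 36");
static_assert(my_lcm(0, 5) == 0, "lcm involving 0 is 0");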
@ -31,6 +31,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/detail/layout.hpp"
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
|
||||
#include "cute/atom/mma_traits_sm90_gmma.hpp"
|
||||
#include "cute/atom/copy_traits_sm90_tma.hpp"
|
||||
@ -322,13 +323,21 @@ is_input_fp8() {
(cute::is_same_v<ElementB, float_e4m3_t> || cute::is_same_v<ElementB, float_e5m2_t>));
}

template <class ElementA, class LayoutA, class ElementB, class LayoutB>
// We need to handle the tuples in this function since it is used in SFINAE dispatch in the CollectiveBuilder.
// At that point, it is not guaranteed that the tuples have been split out into the required parts.
template <class MaybeTupleElementA, class LayoutA, class MaybeTupleElementB, class LayoutB>
constexpr bool
is_use_rmem_A() {

using ElementA = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementA>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementB>;

constexpr bool IsABDifferentWidth = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
constexpr bool HasScales = cute::is_tuple<MaybeTupleElementA>::value ^ cute::is_tuple<MaybeTupleElementB>::value;
constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes<ElementA, ElementB>();
constexpr bool IsLayoutAkBk = cutlass::gemm::detail::is_k_major_A<LayoutA>() &&
cutlass::gemm::detail::is_k_major_B<LayoutB>();
constexpr bool IsUseRmemA = !IsInputSizeTwoBytes && !IsLayoutAkBk;
constexpr bool IsUseRmemA = (!IsInputSizeTwoBytes && !IsLayoutAkBk) || IsABDifferentWidth || HasScales;
return IsUseRmemA;
}

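The predicate above now sends A through registers not only for the original non-16-bit, non-TN case but also whenever the operands differ in width or carry scales. The same decision, restated as a plain constexpr sketch with the type traits collapsed into explicit flags (hypothetical helper, not the CUTLASS function):

// Mirror of the boolean logic in is_use_rmem_A().
constexpr bool use_rmem_a(bool inputs_are_two_bytes,
                          bool both_k_major,
                          bool ab_different_width,
                          bool has_scales) {
  return (!inputs_are_two_bytes && !both_k_major) || ab_different_width || has_scales;
}

// 16-bit x 16-bit TN GEMM: A can be sourced from smem.
static_assert(!use_rmem_a(true, true, false, false), "A stays in smem");
// Mixed-width input (e.g. 16-bit x 4-bit): A is converted in registers first.
static_assert(use_rmem_a(true, true, true, false), "A goes through registers");
// Scaled input: also routed through registers.
static_assert(use_rmem_a(true, true, false, true), "A goes through registers");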
@ -79,6 +79,50 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes> stage_cou
return (CapacityBytes - carveout_bytes) / stage_bytes;
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int CapacityBytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCount<stages> stage_count) {
return stages;
}

template <class Element>
constexpr int get_bits_for_possibly_void_element() {
if constexpr (cute::is_same_v<Element, void>) {
return 0;
}
else {
return sizeof_bits<Element>::value;
}
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int CapacityBytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int carveout_bytes>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout<carveout_bytes> stage_count) {

// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int scale_zero_k_tile = 1;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
constexpr int s_bits = get_bits_for_possibly_void_element<ElementScale>();
constexpr int z_bits = get_bits_for_possibly_void_element<ElementZero>();

constexpr int scale_bytes = (s_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile) / 8;
constexpr int zero_bytes = (z_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile) / 8;
static_assert(scale_bytes % 128 == 0, "Scale bytes must be a multiple of 128");
static_assert(zero_bytes % 128 == 0, "Zero bytes must be a multiple of 128");

// When scales are void, s_bits will be 0 so no smem will be allocated for scales.
constexpr int stage_bytes =
(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
scale_bytes + zero_bytes + stage_barrier_bytes;

return (CapacityBytes - carveout_bytes) / stage_bytes;
}

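To make the arithmetic above concrete, here is a worked example as a small constexpr sketch. All of the numbers are assumptions chosen for illustration (128x128x64 tile, 16-bit A, 4-bit B, 16-bit scales, no zero points, roughly 227 KB of usable smem, no epilogue carveout); they are not taken from this change:

// Hypothetical stage-count calculation following the formula above.
constexpr int tile_m = 128, tile_n = 128, tile_k = 64;
constexpr int a_bits = 16, b_bits = 4, s_bits = 16, z_bits = 0;
constexpr int stage_barrier_bytes = 32;
constexpr int scale_zero_k_tile = 1;

constexpr int scale_bytes = (s_bits * tile_m * scale_zero_k_tile) / 8;  // 256
constexpr int zero_bytes  = (z_bits * tile_m * scale_zero_k_tile) / 8;  // 0

constexpr int stage_bytes =
    (a_bits * tile_m * tile_k) / 8 +                  // 16384 bytes of A per stage
    (b_bits * tile_n * tile_k) / 8 +                  //  4096 bytes of B per stage
    scale_bytes + zero_bytes + stage_barrier_bytes;   // 20768 bytes in total

constexpr int capacity_bytes = 232448;  // assumed usable SM90 smem per CTA
constexpr int carveout_bytes = 0;       // assumed: no epilogue carveout
constexpr int stages = (capacity_bytes - carveout_bytes) / stage_bytes;
static_assert(stages == 11, "11 pipeline stages fit under these assumptions");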
template <class ElementA, class LayoutA, class ElementB, class LayoutB>
|
||||
constexpr bool
|
||||
is_swapAB(){
|
||||
@ -105,29 +149,6 @@ is_warpspecialized_transpose_B(){
|
||||
return IsWarpSpecializedTransposeB;
|
||||
}
|
||||
|
||||
template <typename ElementA, typename ElementB>
|
||||
struct Sm90TypeWidths {
|
||||
static constexpr bool IsElementALarger = (cute::sizeof_bits_v<ElementA>) > cute::sizeof_bits_v<ElementB>;
|
||||
using WideType = cute::conditional_t<IsElementALarger, ElementA, ElementB>;
|
||||
using NarrowType = cute::conditional_t<IsElementALarger, ElementB, ElementA>;
|
||||
};
|
||||
|
||||
|
||||
template <class ElementA, class LayoutA, class ElementB, class LayoutB>
|
||||
constexpr bool
|
||||
sm90_is_narrow_type_k_major() {
|
||||
using Widths = Sm90TypeWidths<ElementA, ElementB>;
|
||||
using NarrowType = typename Widths::NarrowType;
|
||||
using WideType = typename Widths::WideType;
|
||||
|
||||
constexpr bool IsANarrow = cute::is_same_v<NarrowType, ElementA>;
|
||||
constexpr cute::GMMA::Major NarrowGmmaMajor = IsANarrow ? detail::gmma_rs_tag_to_major_A<LayoutA>() :
|
||||
detail::gmma_rs_tag_to_major_B<LayoutB>();
|
||||
|
||||
constexpr bool IsNarrowLayoutKMajor = NarrowGmmaMajor == cute::GMMA::Major::K;
|
||||
return IsNarrowLayoutKMajor;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -163,18 +184,24 @@ struct CollectiveBuilder<
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>) &&
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>) &&
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
|
||||
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
@ -183,7 +210,8 @@ static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> || IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
@ -199,10 +227,12 @@ static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
MmaElementA, MmaElementB, TileShape_MNK>(StageCountType{});
|
||||
/* For FP8 use a separate mainloop compared to other datatypes */
|
||||
using DispatchPolicy = cute::conditional_t<IsFP8Input,
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
/* For FP8 use a separate mainloop compared to other datatypes */
|
||||
cute::conditional_t<IsFP8Input,
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
@ -267,7 +297,7 @@ struct CollectiveBuilder<
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutB>();
|
||||
@ -323,13 +353,13 @@ struct CollectiveBuilder<
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_RS Mixed GEMM
|
||||
// GMMA_TMA_WS_RS Mixed Scaled GEMM
|
||||
template <
|
||||
class ElementPairA_,
|
||||
class GmemLayoutPairA_,
|
||||
class GmemLayoutA_,
|
||||
int AlignmentA,
|
||||
class ElementPairB_,
|
||||
class GmemLayoutPairB_,
|
||||
class GmemLayoutB_,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
@ -341,10 +371,10 @@ struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementPairA_,
|
||||
GmemLayoutPairA_,
|
||||
GmemLayoutA_,
|
||||
AlignmentA,
|
||||
ElementPairB_,
|
||||
GmemLayoutPairB_,
|
||||
GmemLayoutB_,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
@ -357,22 +387,38 @@ struct CollectiveBuilder<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeMixedInput>)>
|
||||
> {
|
||||
|
||||
private:
|
||||
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementPairA_>;
|
||||
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementPairB_>;
|
||||
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementPairA_>;
|
||||
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementPairB_>;
|
||||
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementPairA_>::value && !cute::is_tuple<ElementPairB_>::value;
|
||||
|
||||
public:
|
||||
static constexpr bool IsATransformed = cute::sizeof_bits_v<ElementPairA_> < cute::sizeof_bits_v<ElementPairB_>;
|
||||
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>;
|
||||
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>;
|
||||
static_assert(cute::is_tuple<ElementPairA_>::value ^ cute::is_tuple<ElementPairB_>::value ||
|
||||
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value)),
|
||||
"Either A OR B must be a tuple or the widths of A and B must be different.");
|
||||
|
||||
// Split out items for processing, no splitting for now since scales aren't supported.
|
||||
using ElementA = ElementPairA_;
|
||||
using ElementB = ElementPairB_;
|
||||
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
|
||||
|
||||
using GmemLayoutA = GmemLayoutPairA_;
|
||||
using GmemLayoutB = GmemLayoutPairB_;
|
||||
using GmemLayoutA = GmemLayoutA_;
|
||||
using GmemLayoutB = GmemLayoutB_;
|
||||
|
||||
using ElementPairA = cute::conditional_t<IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementPairA_>;
|
||||
using ElementPairB = cute::conditional_t<!IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementPairB_>;
|
||||
|
||||
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
|
||||
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
|
||||
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutB>();
|
||||
@ -382,12 +428,7 @@ public:
|
||||
|
||||
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands.
|
||||
static constexpr bool SwapAB = !IsATransformed;
|
||||
static_assert(detail::sm90_is_narrow_type_k_major<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(), "The narrow type must be K-major.");
|
||||
|
||||
static_assert((IsATransformed && (cute::sizeof_bits_v<ElementA> <= 8) && (sizeof(ElementB) == 2)) ||
|
||||
(!IsATransformed && (cute::sizeof_bits_v<ElementB> <= 8) && (sizeof(ElementA) == 2)) ||
|
||||
(GmmaMajorA == cute::GMMA::Major::K && GmmaMajorB == cute::GMMA::Major::K),
|
||||
"The unscaled element must be 2 bytes OR both inputs must be K-major");
|
||||
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
|
||||
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
|
||||
|
||||
@ -400,6 +441,7 @@ public:
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>());
|
||||
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<GmmaMajorB, ElementB,
|
||||
@ -407,8 +449,8 @@ public:
|
||||
|
||||
using RealElementA = cute::conditional_t<SwapAB, ElementB, ElementA>;
|
||||
using RealElementB = cute::conditional_t<SwapAB, ElementA, ElementB>;
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
RealElementA, RealElementB, TileShape_MNK>(StageCountType{});
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
|
||||
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK>(StageCountType{});
|
||||
|
||||
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::DefaultCopy, ElementA>>;
|
||||
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::DefaultCopy, ElementB>, void>;
|
||||
@ -416,35 +458,24 @@ public:
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
|
||||
using StrideAPair = TagToStrideA_t<GmemLayoutA>;
|
||||
using StrideBPair = TagToStrideB_t<GmemLayoutB>;
|
||||
using StrideA = TagToStrideA_t<GmemLayoutA>;
|
||||
using StrideB = TagToStrideB_t<GmemLayoutB>;
|
||||
|
||||
using GmemTiledCopyAPair = GmemTiledCopyA;
|
||||
using SmemLayoutAtomAPair = SmemLayoutAtomA;
|
||||
using SmemCopyAtomAPair = SmemCopyAtomA;
|
||||
|
||||
using GmemTiledCopyBPair = GmemTiledCopyB;
|
||||
using SmemLayoutAtomBPair = SmemLayoutAtomB;
|
||||
using SmemCopyAtomBPair = SmemCopyAtomB;
|
||||
|
||||
|
||||
// If the src type of the converter is the same as ElementA,
|
||||
// interpret this as if the user wanted to apply the scale to the A matrix.
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementPairA_,
|
||||
StrideAPair,
|
||||
ElementPairB_,
|
||||
StrideBPair,
|
||||
ElementPairA,
|
||||
StrideA,
|
||||
ElementPairB,
|
||||
StrideB,
|
||||
TiledMma,
|
||||
GmemTiledCopyAPair,
|
||||
SmemLayoutAtomAPair,
|
||||
SmemCopyAtomAPair,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyBPair,
|
||||
SmemLayoutAtomBPair,
|
||||
SmemCopyAtomBPair,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
|
||||
@ -483,7 +514,9 @@ struct CollectiveBuilder<
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>>
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
@ -495,13 +528,16 @@ struct CollectiveBuilder<
|
||||
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>,
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>);
|
||||
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
@ -517,8 +553,9 @@ struct CollectiveBuilder<
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized<
|
||||
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
@ -580,7 +617,7 @@ struct CollectiveBuilder<
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
@ -721,7 +758,7 @@ struct CollectiveBuilder<
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
@ -819,7 +856,7 @@ struct CollectiveBuilder<
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
@ -918,13 +955,20 @@ struct CollectiveBuilder<
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Unsupported Toolkit for SM90 Collective Builder\n");
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif

static constexpr bool IsTmaCompatible = detail::is_aligned<
ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>();
using ExtractedElementA = detail::deduce_mixed_width_dtype_t<0, ElementA>;
using ExtractedElementB = detail::deduce_mixed_width_dtype_t<0, ElementB>;

static constexpr bool IsMixedWidthInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
static constexpr bool IsTmaCompatible = detail::is_aligned<
ExtractedElementA, AlignmentA, ExtractedElementB, AlignmentB, detail::tma_alignment_bytes>();

// Users opt into scales via the builder by passing a tuple of Elements for the input that will be scaled. We detect
// scale support if ONLY one of the inputs has a tuple describing it.
static constexpr bool OnlyOneIsTuple = cute::is_tuple<ElementA>::value ^ cute::is_tuple<ElementB>::value;
static constexpr bool IsDifferentWidth = sizeof_bits<ExtractedElementA>::value != sizeof_bits<ExtractedElementB>::value;
static constexpr bool IsMixedWidthInput = IsDifferentWidth || (IsDifferentWidth && OnlyOneIsTuple);

#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1)))
// Persistent schedules perform best for CUDA Toolkits with version >= 12.1

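To illustrate the tuple convention mentioned in the comment above: the input that should be scaled is described as a tuple whose first entry is the narrow storage type and whose later entries are the scale (and optional zero-point) types, while the other input stays a plain element type. A hedged sketch of what such type arguments could look like, with assumed include paths and aliases rather than a verified builder invocation:

#include "cutlass/numeric_types.h"
#include "cute/container/tuple.hpp"

// Assumed illustration of the "only one side is a tuple" rule: A is stored as
// 4-bit integers and dequantized with half-precision scales, B is plain half.
using ElementAPair = cute::tuple<cutlass::int4b_t, cutlass::half_t>;  // {storage, scale}
using ElementB     = cutlass::half_t;

// Passing ElementAPair where the builder expects ElementA (and ElementB as-is)
// makes OnlyOneIsTuple true, which is how scale support is detected here.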
@ -56,7 +56,7 @@ template <
|
||||
class TransformB
|
||||
>
|
||||
struct CollectiveMma {
|
||||
static_assert(cutlass::detail::dependent_false<ElementA> == 0, "Could not find a mainloop specialization.");
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -73,5 +73,6 @@ struct CollectiveMma {
|
||||
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,741 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
{
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;

using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;

using PipelineParams = typename MainloopPipeline::Params;

static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
    SmemLayoutAtomA{},
    make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
    conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
    SmemLayoutAtomB{},
    make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
    conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
              cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
              "MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
              "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
              "GmemTiledCopy - invalid SM90 TMA copy atom specified.");

// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;

// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
    GmemTiledCopyA{},
    make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
    SmemLayoutA{}(_,_,cute::Int<0>{}),
    make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
    size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
    GmemTiledCopyB{},
    make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
    SmemLayoutB{}(_,_,cute::Int<0>{}),
    make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
    size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any

struct SharedStorage {
  struct TensorStorage : cute::aligned_struct<128> {
    cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
    cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
  } tensors;

  struct TensorMapStorage : cute::aligned_struct<128> {
    cute::TmaDescriptor smem_tensormap_A;
    cute::TmaDescriptor smem_tensormap_B;
  } tensormaps;

  using PipelineStorage = typename MainloopPipeline::SharedStorage;
  PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;

static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>;
using StridesA = cute::conditional_t<IsGroupedGemmKernel, StrideA const*, StrideA>;
using StridesB = cute::conditional_t<IsGroupedGemmKernel, StrideB const*, StrideB>;

// Host side kernel arguments
struct Arguments {
  ElementA const** ptr_A;
  StridesA dA;
  ElementB const** ptr_B;
  StridesB dB;
};
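// Note: ptr_A/ptr_B are device-visible arrays holding one base pointer per batch (or per
// group). For the pointer-array schedule, dA/dB are a single stride shared by every batch;
// for the grouped schedule (IsGroupedGemmKernel), they instead point to one stride per group.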

// Device side kernel params
struct Params {
  TMA_A tma_load_a;
  TMA_B tma_load_b;
  void* tensormaps;
  InternalElementA const** ptr_A;
  StridesA dA;
  InternalElementB const** ptr_B;
  StridesB dB;
};

//
// Methods
//

template <class ProblemShape>
static constexpr Params
to_underlying_arguments(
    ProblemShape problem_shapes,
    Arguments const& args,
    void* workspace) {
  // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
  auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(0), 1);
  auto [M,N,K,L] = problem_shape_MNKL;
  const uint32_t mock_L = 1;

  // These tensor pointers are only used to create the tensormap/TMA descriptor.
  // The tensor address will be replaced with the correct address before the initial TMA load.
  InternalElementA const* ptr_A_first_batch = reinterpret_cast<InternalElementA const*>(args.ptr_A);
  InternalElementB const* ptr_B_first_batch = reinterpret_cast<InternalElementB const*>(args.ptr_B);
  cudaError_t cuda_error = cudaGetLastError(); // clear previous error

  StrideA stride_a;
  StrideB stride_b;
  if constexpr (IsGroupedGemmKernel) {
    // Strides for Grouped Gemm will be replaced prior to the first access regardless
    stride_a = StrideA{};
    stride_b = StrideB{};
  }
  else {
    stride_a = args.dA;
    stride_b = args.dB;
  }
  Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), stride_a));
  Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), stride_b));
  TMA_A tma_load_a = make_tma_copy(
      GmemTiledCopyA{},
      tensor_a,
      SmemLayoutA{}(_,_,cute::Int<0>{}),
      make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
      size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
  TMA_B tma_load_b = make_tma_copy(
      GmemTiledCopyB{},
      tensor_b,
      SmemLayoutB{}(_,_,cute::Int<0>{}),
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
      size<0>(ClusterShape{})); // mcast along M mode for this N load, if any

  void* tensormaps = workspace;

  return {
    tma_load_a,
    tma_load_b,
    tensormaps,
    reinterpret_cast<InternalElementA const**>(args.ptr_A),
    args.dA,
    reinterpret_cast<InternalElementB const**>(args.ptr_B),
    args.dB
  };
}

template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
  constexpr uint32_t NumInputTensors = 2;
  constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
  // Allocate gmem space for input tensormaps per SM: all A tensormap copies followed by all B tensormap copies
  return (NumInputTensors * SizeOfCuTensorMap * sm_count);
}
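// Example (illustrative): with sizeof(cute::TmaDescriptor) == 128 bytes (one CUtensorMap)
// and a device exposing 132 SMs, this reserves 2 * 128 * 132 = 33,792 bytes of workspace --
// one A and one B descriptor slot per SM, all A descriptors followed by all B descriptors.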

template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) {
  return cutlass::Status::kSuccess;
}

template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool
can_implement(
    ProblemShape problem_shapes,
    Arguments const& args) {
  constexpr int tma_alignment_bits = 128;
  constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
  constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;

  bool implementable = true;
  // Check alignment for all problem sizes
  for (int i = 0; i < problem_shapes.groups(); i++) {
    auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
    auto [M,N,K,L] = problem_shape_MNKL;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
  }

  if (!implementable) {
    CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
  }
  return implementable;
}

static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes =
    (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8 +
    (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8;
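// Example (illustrative): for a hypothetical 128x128x64 tile with 16-bit A and B operands,
// each stage expects (128*64*16)/8 + (128*64*16)/8 = 16384 + 16384 = 32768 bytes per
// TMA transaction.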


// Set up the data needed by this collective for load and mma.
// Returns a tuple of tensors. The collective and the kernel layer have the contract that the
// returned tuple must contain at least two elements, with the first two elements being:
// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
  using X = Underscore;
  // Separate out problem shape for convenience
  auto [M,N,K,L] = problem_shape_MNKL;
  const int32_t mock_L = 1;

  // TMA requires special handling of strides to deal with coord codomain mapping
  // Represent the full tensors -- get these from TMA
  Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l)
  Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l)

  // Make tiled views, defer the slice
  Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
  Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)

  return cute::make_tuple(gA_mkl, gB_nkl);
}

// Perform a collective-scoped matrix multiply-accumulate
// Producer Perspective
template <
  class TensorA, class TensorB,
  class TensorMapA, class TensorMapB,
  class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load(
    Params const& mainloop_params,
    MainloopPipeline pipeline,
    PipelineState smem_pipe_write,
    cute::tuple<TensorA, TensorB> const& load_inputs,
    cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
    BlockCoord const& blk_coord,
    KTileIterator k_tile_iter, int k_tile_count,
    int thread_idx,
    uint32_t block_rank_in_cluster,
    TensorStorage& shared_tensors) {
  int warp_idx = canonical_warp_idx_sync();
  int warp_idx_in_warp_group = warp_idx % 4;
  int lane_predicate = cute::elect_one_sync();

  if (warp_idx_in_warp_group == 0 and lane_predicate) {
    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)

    //
    // Prepare the TMA loads for A and B
    //

    constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

    Tensor gA_mkl = get<0>(load_inputs);
    Tensor gB_nkl = get<1>(load_inputs);

    auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
    auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

    // Partition the inputs based on the current block coordinates.
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)

    // Applies the mapping from block_tma_a
    Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
    Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)

    Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
    Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)

    uint16_t mcast_mask_a = 0;
    uint16_t mcast_mask_b = 0;

    // Issue TmaLoads
    // Maps the tile -> block, value
    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
      for (int n = 0; n < size<1>(block_layout); ++n) {
        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
      }
    }

    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
      for (int m = 0; m < size<0>(block_layout); ++m) {
        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
      }
    }

    // Mainloop
    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count) {
      // LOCK smem_pipe_write for _writing_
      pipeline.producer_acquire(smem_pipe_write);

      //
      // Copy gmem to smem for *k_tile_iter
      //

      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

      int write_stage = smem_pipe_write.index();
      copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
      copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
      ++k_tile_iter;

      // Advance smem_pipe_write
      ++smem_pipe_write;
    }
  }
}

// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
  int warp_idx = canonical_warp_idx_sync();
  int warp_idx_in_warp_group = warp_idx % 4;
  int lane_predicate = cute::elect_one_sync();

  // Issue the epilogue waits
  if (warp_idx_in_warp_group == 0 and lane_predicate) {
    // This helps avoid early exit of blocks in Cluster.
    // Waits for all stages to either be released (all
    // Consumer UNLOCKs), or if the stage was never used
    // then it would just be acquired since the phase was
    // still inverted from make_producer_start_state.
    pipeline.producer_tail(smem_pipe_write);
  }
}

/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <
  class FrgTensorC
>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline,
    PipelineState smem_pipe_read,
    FrgTensorC& accum,
    int k_tile_count,
    int thread_idx,
    TensorStorage& shared_tensors,
    Params const& mainloop_params) {
  static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
  static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
  static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
  static_assert(cute::is_void_v<SmemCopyAtomA>,
    "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
  static_assert(cute::is_void_v<SmemCopyAtomB>,
    "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");

  Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)

  //
  // Define C accumulators and A/B partitioning
  //

  TiledMma tiled_mma;
  auto thread_mma = tiled_mma.get_thread_slice(thread_idx);

  Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)

  // Allocate "fragments/descriptors"
  Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)

  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
  CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE

  //
  // PIPELINED MAIN LOOP
  //
  static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
    "ERROR : Incorrect number of MMAs in flight");

  // We release buffers to producer warps (dma load) with some mmas in flight
  PipelineState smem_pipe_release = smem_pipe_read;

  // Prologue GMMAs
  int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

  tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

  warpgroup_fence_operand(accum);
  CUTLASS_PRAGMA_UNROLL
  for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) {
    // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
    auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
    pipeline.consumer_wait(smem_pipe_read, barrier_token);

    int read_stage = smem_pipe_read.index();
    warpgroup_arrive();
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // (V,M,K) x (V,N,K) => (V,M,N)
      cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();

    ++smem_pipe_read;
  }

  warpgroup_fence_operand(accum);
  // Mainloop GMMAs
  k_tile_count -= prologue_mma_count;

  CUTLASS_PRAGMA_NO_UNROLL
  for ( ; k_tile_count > 0; --k_tile_count) {
    // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
    auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
    pipeline.consumer_wait(smem_pipe_read, barrier_token);

    //
    // Compute on k_tile
    //

    int read_stage = smem_pipe_read.index();
    warpgroup_fence_operand(accum);
    warpgroup_arrive();
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // (V,M,K) x (V,N,K) => (V,M,N)
      cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();

    /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
    warpgroup_wait<K_PIPE_MMAS>();
    warpgroup_fence_operand(accum);

    // UNLOCK smem_pipe_release, done _computing_ on it
    pipeline.consumer_release(smem_pipe_release);

    // Advance smem_pipe_read and smem_pipe_release
    ++smem_pipe_read;
    ++smem_pipe_release;
  }

  warpgroup_fence_operand(accum);
}

/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
  // Prologue GMMAs
  int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
  k_tile_count -= prologue_mma_count;

  smem_pipe_release.advance(k_tile_count);

  // Wait on all GMMAs to complete
  warpgroup_wait<0>();

  for (int count = 0; count < prologue_mma_count; ++count) {
    pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
    ++smem_pipe_release;
  }
}

//
// Methods to perform different parts of TMA/Tensormap modifications
//

CUTLASS_DEVICE auto
tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const {
  cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);

  cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
  cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];

  if (cute::elect_one_sync()) {
    // Bringing tensormaps from params to gmem for modification later
    Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{});
    Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{});

    copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(gA_tensormap));
    copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(gB_tensormap));
  }

  return cute::make_tuple(tma_desc_a, tma_desc_b);
}

// Bringing tensormaps to smem (to be done by single thread)
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_fetch_to_smem(
    TensorMapStorage& shared_tensormap,
    cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) const {
  Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{});
  Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{});
  Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{});
  Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{});

  copy(recast<uint128_t>(gA_tensormap), recast<uint128_t>(sA_tensormap));
  copy(recast<uint128_t>(gB_tensormap), recast<uint128_t>(sB_tensormap));

  cp_async_fence();
  cp_async_wait<0>();
}

// Replace address for the global tensor (to be done by single thread)
CUTLASS_DEVICE
void
tensormaps_replace_global_address(
    TensorMapStorage& shared_tensormap,
    Params const& mainloop_params,
    int32_t next_batch) {
  // Replacing global_address for the next batch
  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A,
                                                  mainloop_params.ptr_A[next_batch]);
  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B,
                                                  mainloop_params.ptr_B[next_batch]);
}

// Replace dims and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_replace_global_tensor_properties(
    TensorMapStorage& shared_tensormap,
    Params const& mainloop_params,
    int32_t next_group,
    ProblemShape_MNKL problem_shape_mnkl) {
  const uint32_t M = get<0>(problem_shape_mnkl);
  const uint32_t N = get<1>(problem_shape_mnkl);
  const uint32_t K = get<2>(problem_shape_mnkl);
  // Only consider dimensions and strides that we need to recalculate and replace for each group
  constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
  static_assert(TensorRank == Int<3>{},
    "Descriptor modification for global dims & strides expects rank to be 3.");
  cute::array<uint32_t, TensorRank> prob_shape_A = {1,1,1};
  cute::array<uint64_t, TensorRank> prob_stride_A = {0,0,0};
  cute::array<uint32_t, TensorRank> prob_shape_B = {1,1,1};
  cute::array<uint64_t, TensorRank> prob_stride_B = {0,0,0};

  InternalElementA const* ptr_A = nullptr;
  Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);

  InternalElementB const* ptr_B = nullptr;
  Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]);

  cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a,
                                           prob_shape_A, prob_stride_A);
  cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
                                           prob_shape_B, prob_stride_B);

  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
                                                          prob_shape_A,
                                                          prob_stride_A);
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_B,
                                                          prob_shape_B,
                                                          prob_stride_B);
}

template <class TensorMapA, class TensorMapB, class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_perform_update(
    TensorMapStorage& shared_tensormap,
    Params const& mainloop_params,
    cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
    ProblemShape_MNKL problem_shape_mnkl,
    int32_t next_batch) {
  if (cute::elect_one_sync()) {
    // Bringing tensormaps to smem
    tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps);

    // Replacing global_address for the next batch
    tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch);

    if constexpr (IsGroupedGemmKernel) {
      // Replacing global dims and strides for the next batch
      tensormaps_replace_global_tensor_properties(shared_tensormap,
        mainloop_params, next_batch, problem_shape_mnkl);
    }
  }
}

template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_cp_fence_release(
    TensorMapStorage& shared_tensormap,
    cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
  // The entire warp must do this (i.e., the call is warp-aligned)
  tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A);
  tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B);
}

template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_fence_acquire(cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
  cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
  cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
}
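
// Note (sketch): the kernel layer is expected to drive the helpers above roughly as
//   auto tensormaps = collective.tensormaps_init(params, sm_count, sm_idx);   // once per CTA
//   // ... on every switch to a new batch/group ...
//   collective.tensormaps_perform_update(smem_tensormaps, params, tensormaps, problem_shape, next_batch);
//   collective.tensormaps_cp_fence_release(smem_tensormaps, tensormaps);      // whole warp
//   collective.tensormaps_fence_acquire(tensormaps);                          // before TMA loads reuse them
// Variable names in this sequence are illustrative only.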


};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
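The Arguments of this mainloop expect device-visible arrays of per-batch base pointers. A minimal host-side sketch of preparing such a pointer array is shown below (function and variable names are illustrative, not taken from this diff; error handling is omitted):

#include <cuda_runtime.h>
#include <vector>

// Gather per-batch base pointers on the host and publish them as a device-visible array,
// which is the form consumed by a pointer-array mainloop's Arguments::ptr_A / ptr_B.
void const** make_device_ptr_array(std::vector<void const*> const& host_ptrs) {
  void const** device_array = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&device_array), host_ptrs.size() * sizeof(void const*));
  cudaMemcpy(device_array, host_ptrs.data(), host_ptrs.size() * sizeof(void const*),
             cudaMemcpyHostToDevice);
  return device_array;
}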
@ -134,9 +134,7 @@ struct CollectiveMma<
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<
|
||||
DispatchPolicy::Stages,
|
||||
typename DispatchPolicy::ClusterShape>;
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
@ -336,22 +334,20 @@ struct CollectiveMma<
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// that the tuple must contain at least two elements, with the first two elements being:
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// The rest of the tensors can be specified as needed by this collective.
|
||||
template <class ProblemShape_MNKL,
|
||||
class TileShapeMNK>
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) {
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
@ -362,8 +358,8 @@ struct CollectiveMma<
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl);
|
||||
}
|
||||
@ -379,7 +375,7 @@ struct CollectiveMma<
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB> const& tiled_tensors,
|
||||
cute::tuple<TensorA, TensorB> const& load_inputs,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
@ -405,16 +401,16 @@ struct CollectiveMma<
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(tiled_tensors);
|
||||
Tensor gB_nkl = get<1>(tiled_tensors);
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
|
||||
File diff suppressed because it is too large
@ -106,9 +106,7 @@ struct CollectiveMma<
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<
|
||||
DispatchPolicy::Stages,
|
||||
typename DispatchPolicy::ClusterShape>;
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
@ -147,8 +145,7 @@ struct CollectiveMma<
|
||||
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
|
||||
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct SharedStorage {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
|
||||
@ -244,13 +241,13 @@ struct CollectiveMma<
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <
|
||||
class TensorA, class TMA_LOAD_A,
|
||||
class TensorB, class TMA_LOAD_B,
|
||||
@ -281,8 +278,8 @@ struct CollectiveMma<
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
@ -295,11 +292,11 @@ struct CollectiveMma<
|
||||
auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare TMA membars and PREFETCH
|
||||
@ -335,9 +332,7 @@ struct CollectiveMma<
|
||||
params.is_leader = warp_group_thread_idx == 0;
|
||||
params.num_consumers = NumThreadsPerWarpGroup;
|
||||
|
||||
MainloopPipeline pipeline(
|
||||
storage.pipeline_storage,
|
||||
params);
|
||||
MainloopPipeline pipeline(storage.pipeline_storage, params, ClusterShape{});
|
||||
|
||||
// State variables used for iterating the circular buffer
|
||||
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
|
||||
|
||||
@ -107,9 +107,7 @@ struct CollectiveMma<
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<
|
||||
DispatchPolicy::Stages,
|
||||
typename DispatchPolicy::ClusterShape>;
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
@ -255,22 +253,20 @@ struct CollectiveMma<
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// that the tuple must contain at least two elements, with the first two elements being:
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// The rest of the tensors can be specified as needed by this collective.
|
||||
template <class ProblemShape_MNKL,
|
||||
class TileShapeMNK>
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) {
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
@ -281,8 +277,8 @@ struct CollectiveMma<
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl);
|
||||
}
|
||||
@ -298,14 +294,12 @@ struct CollectiveMma<
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB> const& tiled_tensors,
|
||||
cute::tuple<TensorA, TensorB> const& load_inputs,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors)
|
||||
{
|
||||
|
||||
TensorStorage& shared_tensors) {
|
||||
using namespace cute;
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
@ -322,16 +316,16 @@ struct CollectiveMma<
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(tiled_tensors);
|
||||
Tensor gB_nkl = get<1>(tiled_tensors);
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
@ -386,10 +380,7 @@ struct CollectiveMma<
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write)
|
||||
{
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
@ -418,8 +409,7 @@ struct CollectiveMma<
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params)
|
||||
{
|
||||
Params const& mainloop_params) {
|
||||
using namespace cute;
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
|
||||
@ -108,9 +108,7 @@ struct CollectiveMma<
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<
|
||||
DispatchPolicy::Stages,
|
||||
typename DispatchPolicy::ClusterShape>;
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
@ -261,14 +259,12 @@ struct CollectiveMma<
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// that the tuple must contain at least two elements, with the first two elements being:
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// The rest of the tensors can be specified as needed by this collective.
|
||||
template <class ProblemShape_MNKL,
|
||||
class TileShapeMNK>
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) {
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
@ -279,8 +275,8 @@ struct CollectiveMma<
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl);
|
||||
}
|
||||
@ -296,7 +292,7 @@ struct CollectiveMma<
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB> const& tiled_tensors,
|
||||
cute::tuple<TensorA, TensorB> const& load_inputs,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
@ -320,8 +316,8 @@ struct CollectiveMma<
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(tiled_tensors);
|
||||
Tensor gB_nkl = get<1>(tiled_tensors);
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
@ -42,6 +42,7 @@
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/detail/layout.hpp"
|
||||
#include "cutlass/detail/mma.hpp"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
@ -106,9 +107,12 @@ public:
|
||||
using LayoutC = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideC>;
|
||||
using LayoutD = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideD>;
|
||||
|
||||
// NOTE: 3.0 kernels do not support complex transforms for now ...
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
|
||||
|
||||
static ComplexTransform const kTransformA = cute::is_same_v<typename GemmKernel::CollectiveMainloop::TransformA, cute::conjugate> ?
|
||||
ComplexTransform::kConjugate : ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = cute::is_same_v<typename GemmKernel::CollectiveMainloop::TransformB, cute::conjugate> ?
|
||||
ComplexTransform::kConjugate : ComplexTransform::kNone;
|
||||
|
||||
// Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0
|
||||
using MathOperator = cutlass::arch::OpMultiplyAdd;
|
||||
@ -141,17 +145,17 @@ public:
|
||||
static int const kThreadCount = GemmKernel::MaxThreadsPerBlock;
|
||||
|
||||
// Warp shape is not a primary API type in 3.x
|
||||
// But we can best approximate it by inspecting the TiledMma::TiledShape_MNK
|
||||
// But we can best approximate it by inspecting the TiledMma
|
||||
// For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K
|
||||
// We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads
|
||||
static constexpr int WarpsInMma = cute::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32);
|
||||
static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32);
|
||||
static constexpr int WarpsInMmaM = 4;
|
||||
static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
|
||||
using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<
|
||||
cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM,
|
||||
cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN,
|
||||
cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>;
|
||||
CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM,
|
||||
CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN,
|
||||
CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>;
|
||||
|
||||
static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages;
|
||||
|
||||
@ -270,7 +274,12 @@ public:
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
initialize(
|
||||
Arguments const& args,
|
||||
void* workspace = nullptr,
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
@ -283,20 +292,33 @@ public:
|
||||
// Initialize the Params structure
|
||||
params_ = GemmKernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
int smem_size = GemmKernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
// Don't set the function attributes - require the CudaHostAdapter to set it.
|
||||
if constexpr (kEnableCudaHostAdapter) {
|
||||
CUTLASS_ASSERT(cuda_adapter);
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
//
|
||||
// Account for dynamic smem capacity if needed
|
||||
//
|
||||
int smem_size = GemmKernel::SharedStorageSize;
|
||||
|
||||
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
@ -317,7 +339,7 @@ public:
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::run()");
|
||||
dim3 const block = GemmKernel::get_block_shape();
|
||||
dim3 const grid = get_grid_shape(params);
|
||||
@ -331,13 +353,54 @@ public:
|
||||
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
||||
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
||||
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<GemmKernel>;
|
||||
|
||||
void* kernel_params[] = {¶ms};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
|
||||
if constexpr (kEnableCudaHostAdapter) {
|
||||
//
|
||||
// Use the cuda host adapter
|
||||
//
|
||||
CUTLASS_ASSERT(cuda_adapter);
|
||||
if (cuda_adapter) {
|
||||
|
||||
launch_result = cuda_adapter->launch(
|
||||
grid, cluster, block, smem_size, stream, kernel_params, 0
|
||||
);
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
||||
void const* kernel = (void const*) device_kernel<GemmKernel>;
|
||||
|
||||
launch_result = ClusterLauncher::launch(
|
||||
grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
|
||||
}
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
|
||||
if constexpr (kEnableCudaHostAdapter) {
|
||||
CUTLASS_ASSERT(cuda_adapter);
|
||||
if (cuda_adapter) {
|
||||
void* kernel_params[] = {¶ms};
|
||||
|
||||
launch_result = cuda_adapter->launch(
|
||||
grid, block, smem_size, stream, kernel_params, 0
|
||||
);
|
||||
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
||||
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
@ -356,18 +419,27 @@ public:

/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
run(
  Arguments const& args,
  void* workspace = nullptr,
  cudaStream_t stream = nullptr,
  CudaHostAdapter *cuda_adapter = nullptr
) {
  Status status = initialize(args, workspace, stream);
  if (Status::kSuccess == status) {
    status = run(params_, stream);
    status = run(params_, stream, cuda_adapter);
  }
  return status;
}

/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  return run(args, workspace, stream);
operator()(
  Arguments const& args,
  void* workspace = nullptr,
  cudaStream_t stream = nullptr,
  CudaHostAdapter *cuda_adapter = nullptr) {
  return run(args, workspace, stream, cuda_adapter);
}
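
// Illustrative usage (a sketch, not part of this diff): when CUTLASS is built with
// CUTLASS_ENABLE_CUDA_HOST_ADAPTER, launches are expected to go through a user-provided adapter.
//
//   Gemm gemm_op;                                    // a GemmUniversalAdapter instantiation (assumed)
//   CudaHostAdapter* adapter = /* user-provided launch backend */;
//   gemm_op.run(args, workspace, stream, adapter);   // with the adapter path disabled, passing nullptr
//                                                    // keeps the built-in launch behavior
//
// Per the branches above, a null adapter while kEnableCudaHostAdapter is set makes run()
// return Status::kErrorInternal.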
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
@ -387,7 +459,7 @@ public:
|
||||
////////////////////////////// CUTLASS 2.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmKernel_>
|
||||
template <class GemmKernel_>
|
||||
class GemmUniversalAdapter<
|
||||
GemmKernel_,
|
||||
cute::enable_if_t<not gemm::detail::IsCutlass3GemmKernel<GemmKernel_>::value>>
|
||||
@ -501,9 +573,14 @@ public:
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status initialize(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr
|
||||
) {
|
||||
|
||||
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
|
||||
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter);
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments.
|
||||
@ -513,13 +590,18 @@ public:
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
Status run(
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
|
||||
return underlying_operator_.run(stream);
|
||||
return underlying_operator_.run(stream, cuda_adapter);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
Status operator()(
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
@ -527,12 +609,13 @@ public:
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
Status status = initialize(args, workspace, stream, cuda_adapter);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
status = run(stream, cuda_adapter);
|
||||
}
|
||||
|
||||
return status;
|
||||
|
||||
@ -46,6 +46,7 @@
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/cuda_host_adapter.hpp"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
@ -69,6 +70,8 @@ class GemmUniversalBase {
public:

  using GemmKernel = GemmKernel_;
  static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;

  using ThreadblockShape = typename GemmKernel::Mma::Shape;

  using ElementA = typename GemmKernel::ElementA;
@ -295,7 +298,8 @@ public:
  Status initialize(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr)
    cudaStream_t stream = nullptr,
    CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));
@ -323,9 +327,8 @@ public:
    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr)
  Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::run()");

@ -339,13 +342,27 @@ public:
      "block: (" << block << "), "
      "SMEM: (" << smem_size_ << ")");

    Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
    if constexpr (kEnableCudaHostAdapter) {
      CUTLASS_ASSERT(cuda_adapter);
      if (cuda_adapter) {
        void* kernel_params[] = {&params_};
        return cuda_adapter->launch(grid, block, smem_size_, stream, kernel_params, 0);
      }
      else {
        return Status::kErrorInternal;
      }
    }
    else {
      CUTLASS_ASSERT(cuda_adapter == nullptr);

    // Query for errors
    cudaError_t result = cudaGetLastError();
    if (result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
      return Status::kErrorInternal;
      Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);

      // Query for errors
      cudaError_t result = cudaGetLastError();
      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
@ -363,12 +380,13 @@ public:
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr)
    cudaStream_t stream = nullptr,
    CudaHostAdapter *cuda_adapter = nullptr)
  {
    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
      status = run(stream, cuda_adapter);
    }

    return status;
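When kEnableCudaHostAdapter is set, run() above no longer uses the triple-chevron launch; it packs a pointer to the parameter struct into a void* array and hands grid, block, shared-memory size, stream, and that array to cuda_adapter->launch(). The sketch below shows what such a delegated launch can reduce to with the plain CUDA runtime call cudaLaunchKernel; it illustrates the mechanism only and is not the CudaHostAdapter implementation.

    #include <cuda_runtime.h>

    // Illustration only: a host-side launch helper in the spirit of the
    // cuda_adapter->launch(...) call above. kernel_fn must be the address of a
    // __global__ function whose single parameter is of type Params.
    template <class Params>
    cudaError_t launch_like_host_adapter(
        void const* kernel_fn, dim3 grid, dim3 block,
        size_t smem_bytes, cudaStream_t stream, Params* params) {
      // Same packing as in run(): one entry per kernel parameter, here a
      // single pointer to the params struct.
      void* kernel_params[] = {params};
      return cudaLaunchKernel(kernel_fn, grid, block, kernel_params, smem_bytes, stream);
    }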
@ -53,6 +53,8 @@ struct KernelTma { };
struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperative { };

//////////////////////////////////////////////////////////////////////////////

@ -65,6 +67,8 @@ struct KernelTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum : KernelGroupTmaWarpSpecializedCooperative { };

// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -225,6 +229,23 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
    "KernelSchedule must be one of the warp specialized policies");
};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
template<
  int Stages_,
  class ClusterShape_ = Shape<_1,_1,_1>,
  class KernelSchedule = KernelGroupTmaWarpSpecializedCooperative
>
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
  constexpr static int Stages = Stages_;
  using ClusterShape = ClusterShape_;
  using ArchTag = arch::Sm90;
  using Schedule = KernelSchedule;
  static_assert(
    cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
    cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>,
    "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
};

//////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm
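MainloopSm90ArrayTmaGmmaWarpSpecialized carries the same three knobs as the other Sm90 mainloop policies: a stage count, a cluster shape, and a kernel schedule, and its static_assert restricts the schedule to the two cooperative schedules introduced above. A small instantiation sketch; the stage count and cluster shape below are arbitrary example values, not defaults taken from this diff:

    #include "cutlass/gemm/dispatch_policy.hpp"

    // Arbitrary example: 4 shared-memory stages, a 2x1x1 thread-block cluster,
    // and the Grouped-GEMM cooperative schedule.
    using GroupedMainloop = cutlass::gemm::MainloopSm90ArrayTmaGmmaWarpSpecialized<
        4,
        cute::Shape<cute::_2, cute::_1, cute::_1>,
        cutlass::gemm::KernelGroupTmaWarpSpecializedCooperative>;

    static_assert(GroupedMainloop::Stages == 4,
                  "the stage count is carried by the policy");

    // Substituting, say, KernelTmaWarpSpecializedPingpong as the schedule would
    // trip the policy's static_assert at instantiation time.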
@ -69,6 +69,7 @@ enum class GemmUniversalMode {
  kGemmSplitKParallel,
  kBatched,
  kArray,
  kGrouped,
  kInvalid
};
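kGrouped extends the set of universal modes, so host code that switches over GemmUniversalMode needs one more case. A trivial sketch using only the enumerators visible in this hunk (the enum is assumed to live in cutlass/gemm/gemm.h, as in prior releases):

    #include "cutlass/gemm/gemm.h"

    // Maps the modes shown in the hunk above to printable names; enumerators
    // not visible here (e.g. kGemm) fall through to the default case.
    inline char const* to_string(cutlass::gemm::GemmUniversalMode mode) {
      using cutlass::gemm::GemmUniversalMode;
      switch (mode) {
        case GemmUniversalMode::kGemmSplitKParallel: return "kGemmSplitKParallel";
        case GemmUniversalMode::kBatched:            return "kBatched";
        case GemmUniversalMode::kArray:              return "kArray";
        case GemmUniversalMode::kGrouped:            return "kGrouped";
        case GemmUniversalMode::kInvalid:            return "kInvalid";
        default:                                     return "unrecognized mode";
      }
    }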
include/cutlass/gemm/group_array_problem_shape.hpp (new file, 111 lines)
@ -0,0 +1,111 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief This file contains definitions and utility functions for describing problem shapes
           for 3.x Ptr-Array GEMMs and Grouped GEMMs.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/tensor_coord.h"

#include "cute/container/array.hpp"

#if ! defined(__CUDACC_RTC__)
#include <initializer_list>
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <class ProblemShape_>
struct GroupProblemShape {
  using UnderlyingProblemShape = ProblemShape_;
  int32_t num_groups = 1;
  UnderlyingProblemShape* problem_shapes = nullptr;
  UnderlyingProblemShape const* host_problem_shapes = nullptr;

  CUTLASS_HOST_DEVICE
  int32_t groups() const { return num_groups; }

  CUTLASS_HOST_DEVICE
  UnderlyingProblemShape const
  get_problem_shape(int32_t group_idx) const {
    return problem_shapes[group_idx];
  }

  CUTLASS_HOST_DEVICE
  UnderlyingProblemShape const
  get_host_problem_shape(int32_t group_idx) const {
    return host_problem_shapes[group_idx];
  }
};

template <class ProblemShape_>
class ArrayProblemShape {
public:
  using UnderlyingProblemShape = ProblemShape_;

  ArrayProblemShape() = default;
  ArrayProblemShape(UnderlyingProblemShape ps) : problem_shape_(ps) {}

  // The number of groups for a Ptr-Array GEMM always remains one; only the number of batches (l) can vary.
  // This is just to maintain uniformity with GroupProblemShape.
  constexpr int32_t groups() const { return 1; }

  // Const-qualified pointer accessors so they remain callable on a const object.
  UnderlyingProblemShape const* problem_shapes() const {
    return &problem_shape_;
  }
  UnderlyingProblemShape const* host_problem_shapes() const {
    return &problem_shape_;
  }

  // This is just to maintain uniformity with GroupProblemShape.
  CUTLASS_HOST_DEVICE
  UnderlyingProblemShape const
  get_problem_shape(int32_t /* unused */ = 0) const {
    return problem_shape_;
  }

  CUTLASS_HOST_DEVICE
  UnderlyingProblemShape const
  get_host_problem_shape(int32_t /* unused */ = 0) const {
    return problem_shape_;
  }
private:
  UnderlyingProblemShape problem_shape_{};
};

} // namespace cutlass::gemm
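GroupProblemShape and ArrayProblemShape expose the same groups() / get_problem_shape() / get_host_problem_shape() surface, so the 3.x collective and kernel layers can treat Ptr-Array and Grouped GEMM uniformly. A host-side sketch of filling in a GroupProblemShape follows; the rank-3 (M, N, K) shape type and the device-side copy of the shape array are assumptions used for illustration only:

    #include <cstdint>
    #include <vector>
    #include "cute/tensor.hpp"   // umbrella cute header; provides cute::Shape / cute::make_shape
    #include "cutlass/gemm/group_array_problem_shape.hpp"

    // Assumed per-group problem shape: a cute tuple of runtime (M, N, K) extents.
    using ProblemShapeMNK = cute::Shape<int, int, int>;
    using GroupShape      = cutlass::gemm::GroupProblemShape<ProblemShapeMNK>;

    // Builds the descriptor from a host-side list of shapes and a device-resident
    // copy of the same array (the device copy, e.g. via cudaMalloc/cudaMemcpy, is
    // assumed to exist and is not shown).
    GroupShape make_group_shape(std::vector<ProblemShapeMNK> const& host_shapes,
                                ProblemShapeMNK* device_shapes) {
      return {static_cast<int32_t>(host_shapes.size()),  // num_groups
              device_shapes,                             // problem_shapes (device pointer)
              host_shapes.data()};                       // host_problem_shapes
    }

    // Usage sketch:
    //   std::vector<ProblemShapeMNK> shapes = {cute::make_shape(512, 256, 128),
    //                                          cute::make_shape( 64, 768, 256)};
    //   GroupShape gs = make_group_shape(shapes, device_copy);  // device_copy: hypothetical
    //   int32_t n = gs.groups();                 // 2
    //   auto s1 = gs.get_host_problem_shape(1);  // (64, 768, 256)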
@ -851,7 +851,9 @@ protected:
      }
      ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D;
      if (ptr_Tensor) {
        ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor;
        ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
          ptr_Tensor,
          tile_work.tiled_coord.k() * params.batch_stride_Tensor);
      }
      if (ptr_Vector) {
        ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector;
@ -2024,7 +2026,9 @@ protected:
      ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C;
      ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D;
      if (ptr_Tensor) {
        ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor;
        ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
          ptr_Tensor,
          tile_work.tiled_coord.k() * params.batch_stride_Tensor);
      }
      if (ptr_Vector) {
        ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector;
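ReferenceFactory<...>::add_pointer_offset replaces the raw ptr_Tensor arithmetic because the auxiliary tensor's element type may be narrower than a byte, in which case a plain pointer increment cannot represent an arbitrary element offset. The helper below is hypothetical and only illustrates the underlying issue for packed 4-bit data: an element offset has to decompose into a byte offset plus a position within the byte instead of being added to the pointer directly.

    #include <cstdint>
    #include <utility>

    // Hypothetical illustration (not the CUTLASS ReferenceFactory): advancing a
    // pointer to packed 4-bit elements by element_offset elements yields a byte
    // pointer plus the index of the element within that byte.
    inline std::pair<std::uint8_t*, int>
    add_int4_pointer_offset(std::uint8_t* base, std::int64_t element_offset) {
      std::uint8_t* byte_ptr = base + (element_offset >> 1);          // two int4 values per byte
      int within_byte        = static_cast<int>(element_offset & 1);  // 0 = low nibble, 1 = high nibble
      return {byte_ptr, within_byte};
    }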
@ -67,9 +67,9 @@ class GemmUniversal<
  Epilogue_,
  ThreadblockSwizzle_,
  void,
  // 3.x kernels use the first template argument to define the ProblemShape tuple
  // 3.x kernels use the first template argument to define the ProblemShape
  // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API
  cute::enable_if_t<not cute::is_tuple<Mma_>::value>
  cute::enable_if_t<not (cute::is_tuple<Mma_>::value || IsCutlass3ArrayKernel<Mma_>::value)>
> {
public:

@ -61,6 +61,19 @@ template <
>
class GemmUniversal;

////////////////////////////////////////////////////////////////////////////////

// In cases where ProblemShape is not a tuple, this is used to check if the
// underlying problem shape type is aliased within or not.
// Used for dispatching GemmUniversal to 2.x API or 3.x API
template <class ProblemShape, class = void>
struct IsCutlass3ArrayKernel : cute::false_type { };

template <typename ProblemShape>
struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::UnderlyingProblemShape>>
  : cute::true_type { };

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel
@ -75,4 +88,5 @@ class GemmUniversal;
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp"
////////////////////////////////////////////////////////////////////////////////

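IsCutlass3ArrayKernel is the classic void_t detection idiom: it is true_type exactly when the probed type exposes a nested UnderlyingProblemShape alias, which GroupProblemShape and ArrayProblemShape do and a plain 2.x Mma type does not. A compile-time check sketch; the two probe types are invented for the demonstration, and the include path is assumed from the hunk above:

    // Assumed header: the trait above is declared alongside the GemmUniversal
    // kernel declarations (cutlass/gemm/kernel/gemm_universal.hpp).
    #include "cutlass/gemm/kernel/gemm_universal.hpp"

    // Probe types invented for the demonstration.
    struct FakeArrayProblemShape {
      using UnderlyingProblemShape = int;   // any nested alias with this name is detected
    };
    struct FakeLegacyMma { };               // no such alias

    static_assert(
        cutlass::gemm::kernel::IsCutlass3ArrayKernel<FakeArrayProblemShape>::value,
        "detected: dispatched to the 3.x Ptr-Array / Grouped GemmUniversal");
    static_assert(
        !cutlass::gemm::kernel::IsCutlass3ArrayKernel<FakeLegacyMma>::value,
        "not detected: falls back to the 2.x GemmUniversal specialization");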
@ -42,7 +42,7 @@
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"

#include "cutlass/subbyte_reference.h"
#include "cutlass/trace.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -676,7 +676,9 @@ public:
      }
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
      if (ptr_Tensor) {
        ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
        ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
          ptr_Tensor,
          threadblock_tile_offset.k() * params.batch_stride_Tensor);
      }
      if (ptr_Vector) {
        ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector;
@ -1387,7 +1389,9 @@ public:
      ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
      if (ptr_Tensor) {
        ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
        ptr_Tensor = ReferenceFactory<typename Epilogue::ElementTensor>::add_pointer_offset(
          ptr_Tensor,
          threadblock_tile_offset.k() * params.batch_stride_Tensor);
      }
      if (ptr_Vector) {
        ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector;
@ -217,6 +217,8 @@ public:
    // Only used by device-level operator
    GemmCoord *host_problem_sizes;

    bool allow_early_exit;

    //
    // Methods
    //
@ -235,7 +237,8 @@ public:
      ldb(nullptr),
      ldc(nullptr),
      ldd(nullptr),
      host_problem_sizes(nullptr)
      host_problem_sizes(nullptr),
      allow_early_exit(false)
    {

    }
@ -256,7 +259,8 @@ public:
      typename LayoutB::Stride::LongIndex *ldb,
      typename LayoutC::Stride::LongIndex *ldc,
      typename LayoutC::Stride::LongIndex *ldd,
      GemmCoord *host_problem_sizes=nullptr
      GemmCoord *host_problem_sizes=nullptr,
      bool allow_early_exit=false
    ):
      mode(mode),
      problem_sizes(problem_sizes),
@ -271,7 +275,8 @@ public:
      ldb(ldb),
      ldc(ldc),
      ldd(ldd),
      host_problem_sizes(host_problem_sizes)
      host_problem_sizes(host_problem_sizes),
      allow_early_exit(allow_early_exit)
    {

    }
@ -303,6 +308,7 @@ public:
    typename LayoutC::Stride::LongIndex *ldc;
    typename LayoutC::Stride::LongIndex *ldd;

    bool allow_early_exit;

    //
    // Methods
@ -318,7 +324,8 @@ public:
      lda(nullptr),
      ldb(nullptr),
      ldc(nullptr),
      ldd(nullptr)
      ldd(nullptr),
      allow_early_exit(false)
    { }

    CUTLASS_HOST_DEVICE
@ -333,7 +340,8 @@ public:
      lda(args.lda),
      ldb(args.ldb),
      ldc(args.ldc),
      ldd(args.ldd)
      ldd(args.ldd),
      allow_early_exit(args.allow_early_exit)
    {

    }
@ -388,6 +396,12 @@ public:
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Early exit following LAPACK's definition
    if (params.allow_early_exit &&
        (params.output_op.alpha == ElementC(0)) && (params.output_op.beta == ElementC(1))) {
      return;
    }

    //
    // Problem visitor.
    //

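The early exit encodes the LAPACK convention for D = alpha * A * B + beta * C: when alpha == 0 and beta == 1 the epilogue leaves C untouched, so a grouped kernel whose caller opted in via allow_early_exit can return before the problem visitor runs. A host-side restatement of the same predicate, offered as a sketch only (ElementCompute stands in for the scalar type checked on the device):

    // Mirrors the device-side check above: skip the whole grouped GEMM only
    // when the caller opted in and the epilogue is provably a no-op.
    template <class ElementCompute>
    bool grouped_gemm_is_noop(bool allow_early_exit, ElementCompute alpha, ElementCompute beta) {
      return allow_early_exit &&
             alpha == ElementCompute(0) &&
             beta  == ElementCompute(1);
    }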
Some files were not shown because too many files have changed in this diff.