3.6.0 update (#2005)
* 3.6.0 update * doc and swap stuff --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -118,6 +118,5 @@ class TestEVTCompute(EVTTestCaseBase):
|
||||
result_keys = ["D"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@ -131,6 +131,115 @@ set(header_files_to_check
|
||||
cute/atom/mma_traits_sm80.hpp
|
||||
cute/atom/mma_traits_sm90.hpp
|
||||
cute/atom/mma_traits_sm90_gmma.hpp
|
||||
# cutlass
|
||||
cutlass/aligned_buffer.h
|
||||
cutlass/array.h
|
||||
cutlass/array_planar_complex.h
|
||||
cutlass/array_subbyte.h
|
||||
cutlass/barrier.h
|
||||
cutlass/bfloat16.h
|
||||
cutlass/blas3.h
|
||||
cutlass/blas3_types.h
|
||||
cutlass/block_striped.h
|
||||
cutlass/cluster_launch.hpp
|
||||
cutlass/complex.h
|
||||
cutlass/constants.h
|
||||
cutlass/coord.h
|
||||
cutlass/core_io.h
|
||||
cutlass/cuda_host_adapter.hpp
|
||||
cutlass/cutlass.h
|
||||
cutlass/device_kernel.h
|
||||
cutlass/fast_math.h
|
||||
cutlass/float8.h
|
||||
# cutlass/floating_point_nvrtc.h
|
||||
cutlass/functional.h
|
||||
cutlass/gemm_coord.h
|
||||
cutlass/gemm_coord.hpp
|
||||
cutlass/half.h
|
||||
cutlass/integer_subbyte.h
|
||||
cutlass/kernel_hardware_info.h
|
||||
cutlass/kernel_hardware_info.hpp
|
||||
cutlass/kernel_launch.h
|
||||
cutlass/matrix.h
|
||||
cutlass/matrix_coord.h
|
||||
cutlass/matrix_shape.h
|
||||
cutlass/numeric_conversion.h
|
||||
cutlass/numeric_size.h
|
||||
cutlass/numeric_types.h
|
||||
cutlass/pitch_linear_coord.h
|
||||
cutlass/predicate.h
|
||||
cutlass/predicate_vector.h
|
||||
cutlass/quaternion.h
|
||||
cutlass/real.h
|
||||
cutlass/relatively_equal.h
|
||||
cutlass/semaphore.h
|
||||
cutlass/subbyte_reference.h
|
||||
cutlass/tensor_coord.h
|
||||
cutlass/tensor_ref.h
|
||||
cutlass/tensor_ref_planar_complex.h
|
||||
cutlass/tensor_view.h
|
||||
cutlass/tensor_view_planar_complex.h
|
||||
cutlass/tfloat32.h
|
||||
cutlass/trace.h
|
||||
cutlass/uint128.h
|
||||
cutlass/version.h
|
||||
cutlass/wmma_array.h
|
||||
cutlass/workspace.h
|
||||
# cutlass/platform
|
||||
cutlass/platform/platform.h
|
||||
|
||||
# cutlass/pipeline
|
||||
cutlass/pipeline/pipeline.hpp
|
||||
cutlass/pipeline/sm90_pipeline.hpp
|
||||
# cutlass/detail
|
||||
cutlass/detail/cluster.hpp
|
||||
cutlass/detail/collective.hpp
|
||||
cutlass/detail/dependent_false.hpp
|
||||
cutlass/detail/helper_macros.hpp
|
||||
cutlass/detail/layout.hpp
|
||||
cutlass/detail/mainloop_fusion_helper_bgrada.hpp
|
||||
cutlass/detail/mma.hpp
|
||||
# cutlass/arch
|
||||
cutlass/arch/arch.h
|
||||
cutlass/arch/barrier.h
|
||||
cutlass/arch/cache_operation.h
|
||||
cutlass/arch/config.h
|
||||
cutlass/arch/custom_abi.h
|
||||
cutlass/arch/grid_dependency_control.h
|
||||
cutlass/arch/memory.h
|
||||
# cutlass/arch/memory_sm75.h
|
||||
# cutlass/arch/memory_sm80.h
|
||||
cutlass/arch/mma.h
|
||||
# cutlass/arch/mma_sm50.h
|
||||
# cutlass/arch/mma_sm60.h
|
||||
# cutlass/arch/mma_sm61.h
|
||||
# cutlass/arch/mma_sm70.h
|
||||
# cutlass/arch/mma_sm75.h
|
||||
# cutlass/arch/mma_sm80.h
|
||||
# cutlass/arch/mma_sm89.h
|
||||
# cutlass/arch/mma_sm90.h
|
||||
cutlass/arch/mma_sparse_sm80.h
|
||||
cutlass/arch/mma_sparse_sm89.h
|
||||
# cutlass/arch/simd.h
|
||||
# cutlass/arch/simd_sm60.h
|
||||
# cutlass/arch/simd_sm61.h
|
||||
cutlass/arch/reg_reconfig.h
|
||||
cutlass/arch/tma_operation.h
|
||||
cutlass/arch/wmma.h
|
||||
# cutlass/arch/wmma_sm70.h
|
||||
# cutlass/arch/wmma_sm72.h
|
||||
# cutlass/arch/wmma_sm75.h
|
||||
# cutlass/arch/wmma_sm80.h
|
||||
# cutlass/layout
|
||||
cutlass/layout/layout.h
|
||||
cutlass/layout/matrix.h
|
||||
cutlass/layout/permute.h
|
||||
cutlass/layout/pitch_linear.h
|
||||
cutlass/layout/tensor.h
|
||||
cutlass/layout/tensor_op_multiplicand_sm70.h
|
||||
cutlass/layout/tensor_op_multiplicand_sm75.h
|
||||
cutlass/layout/tensor_op_multiplicand_sm80.h
|
||||
cutlass/layout/vector.h
|
||||
)
|
||||
|
||||
# for each header in _header_files:
|
||||
|
||||
@ -63,7 +63,7 @@ set(CUTLASS_TEST_UNIT_RESULTS_CACHE_DIR ${CMAKE_CURRENT_LIST_DIR}/data/hashes)
|
||||
|
||||
function(cutlass_test_unit_add_executable NAME)
|
||||
|
||||
set(options WITHOUT_CUDA)
|
||||
set(options WITHOUT_CUDA DO_NOT_LOWERCASE_TEST_NAME)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs TEST_SETS_SUPPORTED EXTRA_INCLUDE_DIRS)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
@ -109,14 +109,22 @@ function(cutlass_test_unit_add_executable NAME)
|
||||
|
||||
set(CUTLASS_TEST_UNIT_TEST_COMMAND_OPTIONS --gtest_output=xml:${NAME_STEM}.gtest.xml)
|
||||
|
||||
if (__DO_NOT_LOWERCASE_TEST_NAME)
|
||||
set(DO_NOT_LOWERCASE_TEST_NAME DO_NOT_LOWERCASE_TEST_NAME)
|
||||
else()
|
||||
set(DO_NOT_LOWERCASE_TEST_NAME)
|
||||
endif()
|
||||
|
||||
cutlass_add_executable_tests(
|
||||
${NAME_STEM} ${NAME}
|
||||
TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED}
|
||||
TEST_COMMAND_OPTIONS CUTLASS_TEST_UNIT_TEST_COMMAND_OPTIONS
|
||||
${RESULT_CACHE_FILE_ARGS}
|
||||
${DO_NOT_LOWERCASE_TEST_NAME}
|
||||
)
|
||||
|
||||
endfunction()
|
||||
|
||||
add_custom_target(cutlass_test_unit)
|
||||
add_custom_target(test_unit)
|
||||
|
||||
|
||||
@ -87,7 +87,6 @@ void FilterArchitecture() {
|
||||
<< " [" << cudaGetErrorString(err) << "]" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
cudaDeviceProp deviceProperties;
|
||||
err = cudaGetDeviceProperties(&deviceProperties, cudaDeviceId);
|
||||
if (cudaSuccess != err) {
|
||||
|
||||
@ -1159,6 +1159,37 @@ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>>
|
||||
get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() {
|
||||
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>;
|
||||
std::vector<ProblemShape> problem_shapes;
|
||||
// Test TMA truncation
|
||||
problem_shapes.push_back({
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{1, 512, 64}, // nqk
|
||||
{64, 1, 64}, // ksc
|
||||
{0}, // padding lower (pad_w)
|
||||
{0}, // padding upper (pad_w)
|
||||
{2}, // stride (stride_w)
|
||||
{1}, // dilation (dilation_w)
|
||||
1 // group
|
||||
});
|
||||
problem_shapes.push_back({
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{1, 1024, 64}, // nqk
|
||||
{64, 1, 64}, // ksc
|
||||
{0}, // padding lower (pad_w)
|
||||
{0}, // padding upper (pad_w)
|
||||
{4}, // stride (stride_w)
|
||||
{1}, // dilation (dilation_w)
|
||||
1 // group
|
||||
});
|
||||
problem_shapes.push_back({
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{1, 2048, 64}, // nqk
|
||||
{64, 1, 64}, // ksc
|
||||
{0}, // padding lower (pad_w)
|
||||
{0}, // padding upper (pad_w)
|
||||
{8}, // stride (stride_w)
|
||||
{1}, // dilation (dilation_w)
|
||||
1 // group
|
||||
});
|
||||
// non-packed input/output strides.
|
||||
// stride divides dilation
|
||||
// asymmetric padding
|
||||
|
||||
@ -336,10 +336,17 @@ struct ConvTestbed {
|
||||
|
||||
// Scale
|
||||
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU_taylor<ElementCompute>> ||
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU<ElementCompute>>) {
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU<ElementCompute>> ||
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledSiLu<ElementCompute>> ||
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledHardSwish<ElementCompute>> ) {
|
||||
fusion_args.activation.scale = ElementCompute{1};
|
||||
}
|
||||
|
||||
// LeakyRelu
|
||||
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::LeakyReLU<ElementCompute>> ) {
|
||||
fusion_args.activation.leaky_alpha = ElementCompute{0};
|
||||
}
|
||||
|
||||
cutlass::Status status = cutlass::Status::kInvalid;
|
||||
|
||||
status = conv_op.can_implement(args);
|
||||
@ -617,8 +624,9 @@ bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f
|
||||
for (DecompositionMode decomp_mode : decomposition_modes) {
|
||||
std::vector problem_splits = {Splits{1}};
|
||||
if constexpr (UsesStreamKScheduler) {
|
||||
if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) {
|
||||
if (decomp_mode == DecompositionMode::SplitK) {
|
||||
problem_splits.push_back(Splits{2});
|
||||
problem_splits.push_back(Splits{4});
|
||||
}
|
||||
}
|
||||
for (auto splits : problem_splits) {
|
||||
|
||||
@ -35,6 +35,7 @@
|
||||
#include "../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include <bitset>
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_cute_ampere
|
||||
cp_async.cu
|
||||
cp_sync.cu
|
||||
ldsm.cu
|
||||
cooperative_gemm.cu
|
||||
cooperative_copy.cu
|
||||
|
||||
@ -46,6 +46,7 @@
|
||||
#include <cute/swizzle.hpp> // cute::Swizzle
|
||||
#include <cute/swizzle_layout.hpp> // cute::compose(cute::Swizzle)
|
||||
#include <cute/numeric/numeric_types.hpp>
|
||||
#include <cute/atom/copy_traits_sm80.hpp>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -71,7 +72,7 @@ cooperative_copy_default_gs(T const* g_in, T* g_out, GMemLayout const& gmem_layo
|
||||
Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), gmem_layout);
|
||||
Tensor s_tensor = make_tensor(make_smem_ptr(smem), smem_layout);
|
||||
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, g_in_tensor, s_tensor);
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, g_in_tensor, s_tensor, AutoCopyAsync{});
|
||||
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
@ -84,7 +85,7 @@ cooperative_copy_default_gs(T const* g_in, T* g_out, GMemLayout const& gmem_layo
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, s_tensor, g_out_tensor);
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, s_tensor, g_out_tensor, AutoCopyAsync{});
|
||||
}
|
||||
|
||||
// ss --> shared to shared
|
||||
@ -106,7 +107,7 @@ cooperative_copy_default_ss(T const* g_in, T* g_out, Layout1 const& layout1, Lay
|
||||
Tensor s1_tensor = make_tensor(make_smem_ptr(smem1), layout2);
|
||||
Tensor s2_tensor = make_tensor(make_smem_ptr(smem2), layout1);
|
||||
|
||||
cooperative_copy<ThreadBlockSize, cute::sizeof_bits_v<T>>(threadIdx.x, g_in_tensor, s1_tensor);
|
||||
cooperative_copy<ThreadBlockSize, cute::sizeof_bits_v<T>>(threadIdx.x, g_in_tensor, s1_tensor, AutoCopyAsync{});
|
||||
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
@ -119,10 +120,10 @@ cooperative_copy_default_ss(T const* g_in, T* g_out, Layout1 const& layout1, Lay
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, s1_tensor, s2_tensor);
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, s1_tensor, s2_tensor, AutoCopyAsync{});
|
||||
__syncthreads();
|
||||
|
||||
cooperative_copy<ThreadBlockSize, cute::sizeof_bits_v<T>>(threadIdx.x, s2_tensor, g_out_tensor);
|
||||
cooperative_copy<ThreadBlockSize, cute::sizeof_bits_v<T>>(threadIdx.x, s2_tensor, g_out_tensor, AutoCopyAsync{});
|
||||
}
|
||||
|
||||
// gg --> global to global
|
||||
@ -135,7 +136,7 @@ cooperative_copy_default_gg(T const* g_in, T* g_out, Layout1 const& layout1, Lay
|
||||
Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), layout1);
|
||||
Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), layout2);
|
||||
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, g_in_tensor, g_out_tensor);
|
||||
cooperative_copy<ThreadBlockSize, MaxVecBits>(threadIdx.x, g_in_tensor, g_out_tensor, AutoCopyAsync{});
|
||||
}
|
||||
|
||||
template <class Mode, int MaxVecBits, uint32_t ThreadBlockSize, class T, class Layout1, class Layout2>
|
||||
@ -252,7 +253,7 @@ typedef testing::Types<
|
||||
std::tuple<cooperative_copy_mode::shared_shared, cute::Int<128>>,
|
||||
std::tuple<cooperative_copy_mode::shared_shared, cute::Int<64>>,
|
||||
std::tuple<cooperative_copy_mode::shared_shared, cute::Int<32>>,
|
||||
std::tuple<cooperative_copy_mode::shared_shared, cute::Int<16>>,
|
||||
std::tuple<cooperative_copy_mode::shared_shared, cute::Int<16>>
|
||||
> CooperativeCopyModeMaxVecBitsList;
|
||||
|
||||
TYPED_TEST_SUITE(SM80_CuTe_Ampere, CooperativeCopyModeMaxVecBitsList);
|
||||
|
||||
@ -40,406 +40,462 @@
|
||||
using namespace cute;
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 64;
|
||||
constexpr uint32_t n = 64;
|
||||
constexpr uint32_t k = 64;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F16F16F16F16_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = double;
|
||||
|
||||
constexpr uint32_t m = 64;
|
||||
constexpr uint32_t n = 64;
|
||||
constexpr uint32_t k = 64;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>,
|
||||
Layout<Shape<_2,_2,_1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm3_Half_MMA_CustomSmemLayouts) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 128;
|
||||
constexpr uint32_t n = 128;
|
||||
constexpr uint32_t k = 128;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_128, _128, _128>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>,
|
||||
Layout<Shape<_2, _2, _1>>, // 2x2x1 thread group
|
||||
Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group`
|
||||
>;
|
||||
>{};
|
||||
|
||||
using smem_a_atom_layout_t = Layout<Shape<_64, _8>, Stride< _1,_64>>;
|
||||
using smem_b_atom_layout_t = Layout<Shape< _8,_32>, Stride<_32, _1>>;
|
||||
using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
|
||||
auto smem_a_atom_layout = Layout<Shape<_64, _8>, Stride< _1,_64>>{};
|
||||
auto smem_b_atom_layout = Layout<Shape< _8,_32>, Stride<_32, _1>>{};
|
||||
auto smem_c_atom_layout = make_layout(select<0,1>(shape_mnk));
|
||||
|
||||
test_cooperative_gemm_col_major_layout<smem_a_atom_layout_t,
|
||||
smem_b_atom_layout_t,
|
||||
smem_c_atom_layout_t,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
128,
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size,
|
||||
max_vec_bits,
|
||||
value_type,
|
||||
value_type,
|
||||
value_type>();
|
||||
value_type>
|
||||
(smem_a_atom_layout,
|
||||
smem_b_atom_layout,
|
||||
smem_c_atom_layout,
|
||||
shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm4_Half_MMA_SwizzledSmemLayouts) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 128;
|
||||
constexpr uint32_t n = 128;
|
||||
constexpr uint32_t k = 128;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_128, _128, _128>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>,
|
||||
Layout<Shape<_2, _2, _1>>, // 2x2x1 thread group
|
||||
Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group`
|
||||
>;
|
||||
>{};
|
||||
|
||||
// RowMajor
|
||||
using smem_rowmajor_atom_layout_t = decltype(
|
||||
auto smem_a_atom_layout =
|
||||
composition(Swizzle<3,3,3>{},
|
||||
Layout<Shape < _8,_64>,
|
||||
Stride<_64, _1>>{}));
|
||||
Stride<_64, _1>>{});
|
||||
// ColMajor
|
||||
using smem_colmajor_atom_layout_t = decltype(
|
||||
auto smem_b_atom_layout =
|
||||
composition(Swizzle<3,3,3>{},
|
||||
Layout<Shape <_64, _8>,
|
||||
Stride< _1,_64>>{}));
|
||||
using smem_a_atom_layout_t = smem_rowmajor_atom_layout_t;
|
||||
using smem_b_atom_layout_t = smem_colmajor_atom_layout_t;
|
||||
using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{}), GenRowMajor{}));
|
||||
Stride< _1,_64>>{});
|
||||
|
||||
using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<k> {}), GenRowMajor{}));
|
||||
using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n> {}, Int<k> {}), GenColMajor{}));
|
||||
using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<n> {}), GenRowMajor{}));
|
||||
auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{});
|
||||
|
||||
using smem_a_atom_layout_t = smem_a_atom_layout_t;
|
||||
using smem_a_layout_t = decltype(tile_to_shape(
|
||||
smem_a_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{})))
|
||||
);
|
||||
auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{});
|
||||
auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{});
|
||||
auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{});
|
||||
|
||||
using smem_b_atom_layout_t = smem_b_atom_layout_t;
|
||||
using smem_b_layout_t = decltype(tile_to_shape(
|
||||
smem_b_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{})))
|
||||
);
|
||||
auto smem_a_layout = tile_to_shape(
|
||||
smem_a_atom_layout,
|
||||
make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout)));
|
||||
|
||||
using smem_c_atom_layout_t = smem_c_atom_layout_t;
|
||||
using smem_c_layout_t = decltype(tile_to_shape(
|
||||
smem_c_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{})))
|
||||
);
|
||||
auto smem_b_layout = tile_to_shape(
|
||||
smem_b_atom_layout,
|
||||
make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout)));
|
||||
|
||||
test_cooperative_gemm<gmem_a_layout_t,
|
||||
gmem_b_layout_t,
|
||||
gmem_c_layout_t,
|
||||
smem_a_layout_t,
|
||||
smem_b_layout_t,
|
||||
smem_c_layout_t,
|
||||
SM75_U32x4_LDSM_N, // A
|
||||
SM75_U16x8_LDSM_T, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
128,
|
||||
auto smem_c_layout = tile_to_shape(
|
||||
smem_c_atom_layout,
|
||||
make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout)));
|
||||
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
max_vec_bits,
|
||||
value_type,
|
||||
value_type,
|
||||
value_type>();
|
||||
value_type>
|
||||
(gmem_a_layout,
|
||||
gmem_b_layout,
|
||||
gmem_c_layout,
|
||||
smem_a_layout,
|
||||
smem_b_layout,
|
||||
smem_c_layout,
|
||||
tiled_mma,
|
||||
cute::identity{}, // TransformLoadA
|
||||
cute::identity{}, // TransformLoadB
|
||||
cute::identity{}, // TransformLoadC
|
||||
cute::identity{}, // TransformStoreC
|
||||
SM75_U32x4_LDSM_N{}, // A
|
||||
SM75_U16x8_LDSM_T{}, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>{}); // C
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm5_Double_MMA_SwizzledSmemLayouts) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using value_type = double;
|
||||
|
||||
constexpr uint32_t m = 128;
|
||||
constexpr uint32_t n = 64;
|
||||
constexpr uint32_t k = 16;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_128, _64, _16>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>, // Atom
|
||||
Layout<Shape<_2, _2, _1>>, // Atom layout
|
||||
Tile<Layout<Shape<_16, _2>, Stride<_2, _1>>, // 32x32x4 MMA with perm for load vectorization
|
||||
Layout<Shape<_16, _2>, Stride<_2, _1>>,
|
||||
Underscore>>;
|
||||
Underscore>>{};
|
||||
|
||||
using smem_a_atom_layout_t = decltype(
|
||||
auto smem_a_atom_layout =
|
||||
composition(Swizzle<2,2,2>{},
|
||||
Layout<Shape <_16, _4>,
|
||||
Stride< _1,_16>>{})); // M, K
|
||||
using smem_b_atom_layout_t = decltype(
|
||||
Stride< _1,_16>>{}); // M, K
|
||||
auto smem_b_atom_layout =
|
||||
composition(Swizzle<2,2,2>{},
|
||||
Layout<Shape <_16, _4>,
|
||||
Stride< _1,_16>>{})); // N, K
|
||||
using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{}), GenRowMajor{}));
|
||||
Stride< _1,_16>>{}); // N, K
|
||||
|
||||
using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<k> {}), GenRowMajor{}));
|
||||
using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n> {}, Int<k> {}), GenColMajor{}));
|
||||
using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<n> {}), GenRowMajor{}));
|
||||
auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{});
|
||||
|
||||
using smem_a_atom_layout_t = smem_a_atom_layout_t;
|
||||
using smem_a_layout_t = decltype(tile_to_shape(
|
||||
smem_a_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{})))
|
||||
);
|
||||
using smem_b_atom_layout_t = smem_b_atom_layout_t;
|
||||
using smem_b_layout_t = decltype(tile_to_shape(
|
||||
smem_b_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{})))
|
||||
);
|
||||
using smem_c_atom_layout_t = smem_c_atom_layout_t;
|
||||
using smem_c_layout_t = decltype(tile_to_shape(
|
||||
smem_c_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{})))
|
||||
);
|
||||
auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{});
|
||||
auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{});
|
||||
auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{});
|
||||
|
||||
test_cooperative_gemm<gmem_a_layout_t,
|
||||
gmem_b_layout_t,
|
||||
gmem_c_layout_t,
|
||||
smem_a_layout_t,
|
||||
smem_b_layout_t,
|
||||
smem_c_layout_t,
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
128,
|
||||
auto smem_a_layout = tile_to_shape(
|
||||
smem_a_atom_layout,
|
||||
make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout)));
|
||||
auto smem_b_layout = tile_to_shape(
|
||||
smem_b_atom_layout,
|
||||
make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout)));
|
||||
auto smem_c_layout = tile_to_shape(
|
||||
smem_c_atom_layout,
|
||||
make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout)));
|
||||
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
max_vec_bits,
|
||||
value_type,
|
||||
value_type,
|
||||
value_type>();
|
||||
value_type>
|
||||
(gmem_a_layout,
|
||||
gmem_b_layout,
|
||||
gmem_c_layout,
|
||||
smem_a_layout,
|
||||
smem_b_layout,
|
||||
smem_c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm6_MixedPrecisionFP16FP32_MMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using TA = cutlass::half_t;
|
||||
using TB = cutlass::half_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t m = 64;
|
||||
constexpr uint32_t n = 64;
|
||||
constexpr uint32_t k = 64;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F32F16F16F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 128, TA, TB, TC>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm7_MixedPrecisionBF16FP32_MMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using TA = cutlass::bfloat16_t;
|
||||
using TB = cutlass::bfloat16_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t m = 64;
|
||||
constexpr uint32_t n = 64;
|
||||
constexpr uint32_t k = 64;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F32BF16BF16F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 128, TA, TB, TC>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_MMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using TA = cutlass::tfloat32_t;
|
||||
using TB = cutlass::tfloat32_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t m = 64;
|
||||
constexpr uint32_t n = 64;
|
||||
constexpr uint32_t k = 64;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 128, TA, TB, TC>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA) {
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA_Dynamic) {
|
||||
constexpr uint32_t thread_block_size = 256;
|
||||
constexpr int MaxVecBits = 128;
|
||||
using TA = cutlass::complex<double>;
|
||||
using TB = cutlass::complex<double>;
|
||||
using TC = cutlass::complex<double>;
|
||||
|
||||
constexpr uint32_t thread_block_size = 256;
|
||||
constexpr int MaxVecBits = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_8x8x4_C64C64C64C64_TN>,
|
||||
Layout<Shape<_4, _4, _1>, Stride<_1, _4, _0>>,
|
||||
Tile<Underscore, Underscore, Underscore>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using ALayout = Layout<Shape<Int<13>,Int<35>>, Stride<Int<44>, Int<1> >>;
|
||||
using BLayout = Layout<Shape< Int<7>, Int<35>>, Stride<Int<44>, Int<1> >>;
|
||||
using CLayout = Layout<Shape<Int<13>, Int<7>>, Stride< Int<1>, Int<30>>>;
|
||||
auto a_layout = make_layout(Shape<Int<13>,Int<35>>{}, make_stride(44, 1));
|
||||
auto b_layout = make_layout(Shape< Int<7>, Int<35>>{}, make_stride(44, 1));
|
||||
auto c_layout = make_layout(Shape<Int<13>, Int<7>>{}, make_stride(1, 30));
|
||||
|
||||
|
||||
test_cooperative_gemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
MaxVecBits,
|
||||
TA,
|
||||
TB,
|
||||
TC>();
|
||||
TA, TB, TC>
|
||||
(a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA) {
|
||||
constexpr uint32_t thread_block_size = 256;
|
||||
constexpr int MaxVecBits = 128;
|
||||
using TA = cutlass::complex<double>;
|
||||
using TB = cutlass::complex<double>;
|
||||
using TC = cutlass::complex<double>;
|
||||
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_8x8x4_C64C64C64C64_TN>,
|
||||
Layout<Shape<_4, _4, _1>, Stride<_1, _4, _0>>,
|
||||
Tile<Underscore, Underscore, Underscore>
|
||||
>{};
|
||||
|
||||
auto a_layout = Layout<Shape<Int<13>,Int<35>>, Stride<Int<44>, Int<1> >>{};
|
||||
auto b_layout = Layout<Shape< Int<7>, Int<35>>, Stride<Int<44>, Int<1> >>{};
|
||||
auto c_layout = Layout<Shape<Int<13>, Int<7>>, Stride< Int<1>, Int<30>>>{};
|
||||
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
MaxVecBits,
|
||||
TA, TB, TC>
|
||||
(a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm10_F16F64F16_FMA) {
|
||||
|
||||
constexpr uint32_t thread_block_size = 256;
|
||||
constexpr int MaxVecBits = 128;
|
||||
using TA = cutlass::half_t;
|
||||
using TB = double;
|
||||
using TC = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t thread_block_size = 256;
|
||||
constexpr int MaxVecBits = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<UniversalFMA<half_t, half_t, double, half_t>>,
|
||||
Layout<Shape<_16, _16, _1>, Stride<_1, _16, _0>>,
|
||||
Tile<Underscore, Underscore, Underscore>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using ALayout = Layout<Shape<Int<64>,Int<64>>, Stride<Int<64>, Int< 1>>>;
|
||||
using BLayout = Layout<Shape<Int<64>,Int<64>>, Stride<Int< 1>, Int<64>>>;
|
||||
using CLayout = Layout<Shape<Int<64>,Int<64>>, Stride<Int< 1>, Int<64>>>;
|
||||
auto a_layout = Layout<Shape<Int<64>,Int<64>>, Stride<Int<64>, Int< 1>>>{};
|
||||
auto b_layout = Layout<Shape<Int<64>,Int<64>>, Stride<Int< 1>, Int<64>>>{};
|
||||
auto c_layout = Layout<Shape<Int<64>,Int<64>>, Stride<Int< 1>, Int<64>>>{};
|
||||
|
||||
|
||||
test_cooperative_gemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
MaxVecBits,
|
||||
TA,
|
||||
TB,
|
||||
TC>();
|
||||
TC>
|
||||
(a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemmComposedStride) {
|
||||
|
||||
using T = cute::half_t;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr int MaxVecBits = 16;
|
||||
using T = cute::half_t;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>,
|
||||
Layout<Shape<_2, _2, _1>, Stride<_1, _2, _0>>,
|
||||
Tile<Underscore, Underscore, Underscore>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using swizzle = cute::Swizzle<3, 3, 3>;
|
||||
using offset = cute::_0;
|
||||
using atom_tile_right = decltype(cute::make_layout(cute::Shape<cute::_8, cute::_64>{}, cute::LayoutRight{}));
|
||||
using FP16AtomLayoutRight = decltype(cute::composition(swizzle{}, offset{}, atom_tile_right{}));
|
||||
auto swizzle = cute::Swizzle<3, 3, 3>{};
|
||||
auto offset = cute::_0{};
|
||||
auto atom_tile_right = cute::make_layout(cute::Shape<cute::_8, cute::_64>{}, cute::LayoutRight{});
|
||||
auto FP16AtomLayoutRight = cute::composition(swizzle, offset, atom_tile_right);
|
||||
|
||||
using shape = cute::Shape<cute::Int<128>, cute::Int<128>>;
|
||||
using global_a_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{}));
|
||||
using global_b_layout = decltype(cute::make_layout(shape{}, cute::LayoutLeft{}));
|
||||
using global_c_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{}));
|
||||
auto shape = cute::Shape<cute::Int<128>, cute::Int<128>>{};
|
||||
auto global_a_layout = cute::make_layout(shape, cute::LayoutRight{});
|
||||
auto global_b_layout = cute::make_layout(shape, cute::LayoutLeft{});
|
||||
auto global_c_layout = cute::make_layout(shape, cute::LayoutRight{});
|
||||
|
||||
// This is for A row major, B col major according to CUTLASS default configs
|
||||
using ALayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_a_layout{}));
|
||||
using BLayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_b_layout{}));
|
||||
using CLayout = global_c_layout;
|
||||
auto a_layout = cute::tile_to_shape(FP16AtomLayoutRight, global_a_layout);
|
||||
auto b_layout = cute::tile_to_shape(FP16AtomLayoutRight, global_b_layout);
|
||||
auto c_layout = global_c_layout;
|
||||
|
||||
test_cooperative_gemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
MaxVecBits,
|
||||
T,
|
||||
T,
|
||||
T>();
|
||||
T, T, T>
|
||||
(a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM89_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_Transform) {
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_Transform) {
|
||||
constexpr uint32_t thread_block_size = 64;
|
||||
constexpr uint32_t max_vec_bits = 16;
|
||||
using TA = cutlass::tfloat32_t;
|
||||
using TB = cutlass::tfloat32_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t m = 9;
|
||||
constexpr uint32_t n = 9;
|
||||
constexpr uint32_t k = 9;
|
||||
|
||||
constexpr uint32_t thread_block_size = 64;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = Shape<C<9>, C<9>, C<9>>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
|
||||
Layout<Shape<_1, _2, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 16, TA, TB, TC>(cute::negate{}, cute::negate{}, cute::negate{}, cute::negate{});
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>
|
||||
(shape_mnk, tiled_mma, cute::negate{}, cute::negate{}, cute::negate{}, cute::negate{});
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_TransformPrecision) {
|
||||
constexpr uint32_t thread_block_size = 64;
|
||||
constexpr uint32_t max_vec_bits = 16;
|
||||
using InputTA = cutlass::half_t;
|
||||
using InputTB = cutlass::half_t;
|
||||
using InputTC = cutlass::half_t;
|
||||
|
||||
using ComputeTA = cutlass::tfloat32_t;
|
||||
using ComputeTB = cutlass::tfloat32_t;
|
||||
using ComputeTC = float;
|
||||
|
||||
auto shape_mnk = Shape<C<9>, C<9>, C<9>>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
|
||||
Layout<Shape<_1, _2, _1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, InputTA, InputTB, InputTC>
|
||||
(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_TransformPrecisionReg) {
|
||||
constexpr uint32_t thread_block_size = 64;
|
||||
constexpr uint32_t max_vec_bits = 16;
|
||||
using InputTA = cutlass::half_t;
|
||||
using InputTB = cutlass::half_t;
|
||||
using InputTC = cutlass::half_t;
|
||||
|
||||
using ComputeTA = cutlass::tfloat32_t;
|
||||
using ComputeTB = cutlass::tfloat32_t;
|
||||
using ComputeTC = float;
|
||||
|
||||
auto shape_mnk = Shape<C<9>, C<9>, C<9>>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
|
||||
Layout<Shape<_1, _2, _1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout_rmem_c<thread_block_size, max_vec_bits, InputTA, InputTB, InputTC>
|
||||
(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA_Reg) {
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F16F16F16F16_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout_rmem_c<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA_Reg) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = double;
|
||||
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>,
|
||||
Layout<Shape<_2,_2,_1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout_rmem_c<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA_Predicated_Reg) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = double;
|
||||
|
||||
auto shape_mnk = Shape<C<62>, C<62>, C<62>>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>,
|
||||
Layout<Shape<_2,_2,_1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout_rmem_c<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
@ -69,14 +69,12 @@ test2(double const* g_in, double* g_out)
|
||||
|
||||
copy(g_tensor, s_tensor);
|
||||
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
g_out[threadIdx.x] = 2 * smem[threadIdx.x];
|
||||
}
|
||||
|
||||
TEST(SM80_CuTe_Ampere, CpAsync)
|
||||
TEST(SM80_CuTe_Ampere, CpSync)
|
||||
{
|
||||
constexpr int count = 32;
|
||||
thrust::host_vector<double> h_in(count);
|
||||
File diff suppressed because it is too large
Load Diff
@ -104,7 +104,7 @@ TEST(CuTe_core, Inverse_left)
|
||||
auto layout = Layout<Shape <_8, _4>,
|
||||
Stride<_4, _1>>{};
|
||||
|
||||
test_left_inverse(filter(layout));
|
||||
test_left_inverse(layout);
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@ -44,91 +44,74 @@ using namespace cute;
|
||||
#if USE_FP8
|
||||
TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF8) {
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr int MaxVecBits = 16;
|
||||
using TA = uint8_t;
|
||||
using TB = uint8_t;
|
||||
using TC = uint32_t;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr int MaxVecBits = 16;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>,
|
||||
Layout<Shape<_2, _2, _1>, Stride<_1, _2, _0>>,
|
||||
Tile<_32, _32, _32>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using swizzle = Swizzle<2, 4, 3>;
|
||||
auto swizzle = Swizzle<2, 4, 3>{};
|
||||
|
||||
// This is for A row major, B col major according to CUTLASS default configs
|
||||
using ALayout = decltype(composition(swizzle{}, Layout<Shape<_64, _64>, Stride<_64, _1>>{}));
|
||||
using BLayout = decltype(composition(swizzle{}, Layout<Shape<_64, _64>, Stride<_1, _64>>{}));
|
||||
auto a_layout = composition(swizzle, Layout<Shape<_64, _64>, Stride<_64, _1>>{});
|
||||
auto b_layout = composition(swizzle, Layout<Shape<_64, _64>, Stride<_1, _64>>{});
|
||||
auto c_layout = make_layout(Shape<_64, _64>{}, LayoutLeft{});
|
||||
|
||||
using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{}));
|
||||
|
||||
test_cooperative_gemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
MaxVecBits,
|
||||
TA,
|
||||
TB,
|
||||
TC>();
|
||||
|
||||
TA, TB, TC>
|
||||
(a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF16) {
|
||||
|
||||
constexpr uint32_t thread_block_size = 64;
|
||||
constexpr int max_vec_bits = 16;
|
||||
using TA = half_t;
|
||||
using TB = half_t;
|
||||
using TC = half_t;
|
||||
|
||||
constexpr uint32_t thread_block_size = 64;
|
||||
constexpr int MaxVecBits = 16;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>,
|
||||
Layout<Shape<_2, _1, _1>, Stride<_1, _0, _0>>,
|
||||
Tile<_32, _32, _32>
|
||||
>;
|
||||
|
||||
using swizzle = Swizzle<3, 3, 3>;
|
||||
>{};
|
||||
|
||||
// This is for A row major, B col major according to CUTLASS default configs
|
||||
using ALayout = decltype(composition(swizzle{},
|
||||
Layout<Shape<_64, _64>, Stride<_64, _1>>{}));
|
||||
auto swizzle = Swizzle<3, 3, 3>{};
|
||||
auto ALayout = composition(swizzle{}, Layout<Shape<_64, _64>, Stride<_64, _1>>{});
|
||||
auto BLayout = composition(swizzle{}, Layout<Shape<_64, _64>, Stride<_1, _64>>{});
|
||||
auto CLayout = make_layout(Shape<_64, _64>{}, LayoutLeft{});
|
||||
|
||||
using BLayout = decltype(composition(swizzle{},
|
||||
Layout<Shape<_64, _64>, Stride<_1, _64>>{}));
|
||||
|
||||
using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{}));
|
||||
|
||||
test_cooperative_gemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
MaxVecBits,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
max_vec_bits,
|
||||
TA,
|
||||
TB,
|
||||
TC>();
|
||||
TC>
|
||||
|
||||
(ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -131,7 +131,7 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
for (int stage = 0; stage < size<1>(tAgA); ++stage)
|
||||
{
|
||||
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
|
||||
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<T, size(sA)>);
|
||||
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<T, CUTE_STATIC_V(size(filter_zeros(sA)))>);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
|
||||
@ -146,7 +146,7 @@ tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout
|
||||
for (int stage = 0; stage < size<1>(tAgA); ++stage)
|
||||
{
|
||||
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
|
||||
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<T, size(sA)>);
|
||||
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<T, CUTE_STATIC_V(size(filter_zeros(sA)))>);
|
||||
|
||||
if (elect_one_thr)
|
||||
{
|
||||
|
||||
@ -38,21 +38,19 @@
|
||||
using namespace cute;
|
||||
|
||||
TEST(SM75_CuTe_Turing, CooperativeGemm1_MixedPrecisionFP16FP32_MMA) {
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using TA = cutlass::half_t;
|
||||
using TB = cutlass::half_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t m = 64;
|
||||
constexpr uint32_t n = 64;
|
||||
constexpr uint32_t k = 64;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = make_shape(_64{}, _64{}, _64{});
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 128, TA, TB, TC>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
@ -40,105 +40,85 @@
|
||||
using namespace cute;
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA) {
|
||||
using value_type = float;
|
||||
|
||||
constexpr uint32_t m = 64;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 16;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = float;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = make_shape(_64{}, _32{}, _16{});
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
|
||||
Layout<Shape<_16, _8, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication) {
|
||||
using value_type = float;
|
||||
|
||||
constexpr uint32_t m = 88;
|
||||
constexpr uint32_t n = 20;
|
||||
constexpr uint32_t k = 12;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = float;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = make_shape(C<88>{}, C<20>{}, C<12>{});
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
|
||||
Layout<Shape<_2, _64, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication2) {
|
||||
using value_type = float;
|
||||
|
||||
constexpr uint32_t m = 88;
|
||||
constexpr uint32_t n = 36;
|
||||
constexpr uint32_t k = 24;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = float;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = make_shape(C<88>{}, C<36>{}, C<24>{});
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
|
||||
Layout<Shape<_4, _32, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication3) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = float;
|
||||
|
||||
constexpr uint32_t m = 67;
|
||||
constexpr uint32_t n = 13;
|
||||
constexpr uint32_t k = 11;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = make_shape(C<67>{}, C<13>{}, C<11>{});
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
|
||||
Layout<Shape<_1, _128, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm2_DoubleFMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = double;
|
||||
|
||||
constexpr uint32_t m = 16;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 32;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t =
|
||||
auto shape_mnk = make_shape(C<16>{}, C<32>{}, C<32>{});
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
|
||||
Layout<Shape<_16, _8, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) {
|
||||
using value_type = float;
|
||||
|
||||
constexpr uint32_t m = 32;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 32;
|
||||
|
||||
constexpr uint32_t thread_block_size = 256;
|
||||
using value_type = float;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(_32{}, _32{}, _32{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<
|
||||
UniversalFMA<value_type, value_type, value_type, value_type>
|
||||
>,
|
||||
@ -154,228 +134,188 @@ TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) {
|
||||
>,
|
||||
Underscore
|
||||
>
|
||||
>;
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm4_Half_MMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 32;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 32;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(_32{}, _32{}, _32{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
|
||||
Layout<Shape<_4, _4, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using smem_a_atom_layout_t = typename tiled_mma_t::AtomLayoutB_TV;
|
||||
using smem_b_atom_layout_t = typename tiled_mma_t::AtomLayoutA_TV;
|
||||
using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<n> {})));
|
||||
auto smem_a_atom_layout = typename decltype(tiled_mma)::AtomLayoutB_TV{};
|
||||
auto smem_b_atom_layout = typename decltype(tiled_mma)::AtomLayoutA_TV{};
|
||||
auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk));
|
||||
|
||||
test_cooperative_gemm_col_major_layout<smem_a_atom_layout_t,
|
||||
smem_b_atom_layout_t,
|
||||
smem_c_atom_layout_t,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
value_type>();
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size,
|
||||
value_type>
|
||||
(smem_a_atom_layout,
|
||||
smem_b_atom_layout,
|
||||
smem_c_atom_layout,
|
||||
shape_mnk,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA) {
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 32;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 32;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(_32{}, _32{}, _32{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
|
||||
Layout<Shape<_4, _4, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
|
||||
using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
|
||||
using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
|
||||
auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk));
|
||||
auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{});
|
||||
auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk));
|
||||
|
||||
using smem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
|
||||
using smem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
|
||||
using smem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
|
||||
auto smem_a_layout = make_layout(select<0, 2>(shape_mnk));
|
||||
auto smem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{});
|
||||
auto smem_c_layout = make_layout(select<0, 1>(shape_mnk));
|
||||
|
||||
test_cooperative_gemm<gmem_a_layout_t,
|
||||
gmem_b_layout_t,
|
||||
gmem_c_layout_t,
|
||||
smem_a_layout_t,
|
||||
smem_b_layout_t,
|
||||
smem_c_layout_t,
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
128,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
max_vec_bits,
|
||||
value_type,
|
||||
value_type,
|
||||
value_type>();
|
||||
value_type>
|
||||
(gmem_a_layout,
|
||||
gmem_b_layout,
|
||||
gmem_c_layout,
|
||||
smem_a_layout,
|
||||
smem_b_layout,
|
||||
smem_c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA_Predicated) {
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 31;
|
||||
constexpr uint32_t n = 27;
|
||||
constexpr uint32_t k = 17;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 16;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(C<31>{}, C<27>{}, C<17>{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
|
||||
Layout<Shape<_4, _4, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
|
||||
using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
|
||||
using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
|
||||
auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk));
|
||||
auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{});
|
||||
auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk));
|
||||
|
||||
using smem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
|
||||
using smem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
|
||||
using smem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
|
||||
auto smem_a_layout = make_layout(select<0, 2>(shape_mnk));
|
||||
auto smem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{});
|
||||
auto smem_c_layout = make_layout(select<0, 1>(shape_mnk));
|
||||
|
||||
test_cooperative_gemm<gmem_a_layout_t,
|
||||
gmem_b_layout_t,
|
||||
gmem_c_layout_t,
|
||||
smem_a_layout_t,
|
||||
smem_b_layout_t,
|
||||
smem_c_layout_t,
|
||||
AutoVectorizingCopyWithAssumedAlignment<16>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<16>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<16>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
16,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
max_vec_bits,
|
||||
value_type,
|
||||
value_type,
|
||||
value_type>();
|
||||
value_type>
|
||||
(gmem_a_layout,
|
||||
gmem_b_layout,
|
||||
gmem_c_layout,
|
||||
smem_a_layout,
|
||||
smem_b_layout,
|
||||
smem_c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm6_Half_MAA_SwizzledSmemLayouts) {
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 128;
|
||||
constexpr uint32_t n = 128;
|
||||
constexpr uint32_t k = 64;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 128;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(_128{}, _128{}, _64{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
|
||||
Layout<Shape<_4, _4, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
using smem_a_atom_layout_t = decltype(
|
||||
composition(Swizzle<3,3,3>{},
|
||||
Layout<Shape < _8,_64>,
|
||||
Stride<_64, _1>>{}));
|
||||
using smem_b_atom_layout_t = decltype(
|
||||
composition(Swizzle<3,3,3>{},
|
||||
Layout<Shape <_64, _8>,
|
||||
Stride< _1,_64>>{}));
|
||||
using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{}), GenRowMajor{}));
|
||||
auto smem_a_atom_layout = composition(Swizzle<3,3,3>{}, Layout<Shape < _8,_64>, Stride<_64, _1>>{});
|
||||
auto smem_b_atom_layout = composition(Swizzle<3,3,3>{}, Layout<Shape <_64, _8>, Stride< _1,_64>>{});
|
||||
auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{});
|
||||
|
||||
using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<k> {}), GenRowMajor{}));
|
||||
using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n> {}, Int<k> {}), GenColMajor{}));
|
||||
using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<n> {}), GenRowMajor{}));
|
||||
auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{});
|
||||
auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{});
|
||||
auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{});
|
||||
|
||||
using smem_a_atom_layout_t = smem_a_atom_layout_t;
|
||||
using smem_a_layout_t = decltype(tile_to_shape(
|
||||
smem_a_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{})))
|
||||
);
|
||||
auto smem_a_layout = tile_to_shape(
|
||||
smem_a_atom_layout,
|
||||
make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout)));
|
||||
|
||||
// Transposed
|
||||
using smem_b_atom_layout_t = smem_b_atom_layout_t;
|
||||
using smem_b_layout_t = decltype(tile_to_shape(
|
||||
smem_b_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{})))
|
||||
);
|
||||
auto smem_b_layout = tile_to_shape(
|
||||
smem_b_atom_layout,
|
||||
make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout)));
|
||||
|
||||
using smem_c_atom_layout_t = smem_c_atom_layout_t;
|
||||
using smem_c_layout_t = decltype(tile_to_shape(
|
||||
smem_c_atom_layout_t{},
|
||||
make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{})))
|
||||
);
|
||||
auto smem_c_layout = tile_to_shape(
|
||||
smem_c_atom_layout,
|
||||
make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout)));
|
||||
|
||||
test_cooperative_gemm<gmem_a_layout_t,
|
||||
gmem_b_layout_t,
|
||||
gmem_c_layout_t,
|
||||
smem_a_layout_t,
|
||||
smem_b_layout_t,
|
||||
smem_c_layout_t,
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // A
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // B
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>, // C
|
||||
thread_block_size,
|
||||
tiled_mma_t,
|
||||
128,
|
||||
test_cooperative_gemm<thread_block_size,
|
||||
max_vec_bits,
|
||||
value_type,
|
||||
value_type,
|
||||
value_type>();
|
||||
value_type>
|
||||
(gmem_a_layout,
|
||||
gmem_b_layout,
|
||||
gmem_c_layout,
|
||||
smem_a_layout,
|
||||
smem_b_layout,
|
||||
smem_c_layout,
|
||||
tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_FMA) {
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 64;
|
||||
using TA = float;
|
||||
using TB = float;
|
||||
using TC = double;
|
||||
|
||||
constexpr uint32_t m = 32;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 32;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(_32{}, _32{}, _32{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<UniversalFMA<TC, TA, TB, TC>>,
|
||||
Layout<Shape<_16, _8, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
auto aload = cute::negate {};
|
||||
auto bload = cute::negate {};
|
||||
auto cload = cute::negate {};
|
||||
auto cstore = cute::negate {};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 64, TA, TB, TC>(
|
||||
aload, bload, cload, cstore);
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(
|
||||
shape_mnk, tiled_mma, aload, bload, cload, cstore);
|
||||
}
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_MMA) {
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
constexpr uint32_t m = 32;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 32;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
using value_type = cutlass::half_t;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(_32{}, _32{}, _32{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
|
||||
Layout<Shape<_4, _4, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
auto aload = cute::negate {};
|
||||
auto bload = cute::negate {};
|
||||
auto cload = cute::negate {};
|
||||
auto cstore = cute::negate {};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>(
|
||||
aload, bload, cload, cstore);
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, value_type>(
|
||||
shape_mnk, tiled_mma, aload, bload, cload, cstore);
|
||||
}
|
||||
|
||||
template<class ConstantType>
|
||||
@ -398,26 +338,25 @@ struct convert_to {
|
||||
};
|
||||
|
||||
TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformCustomOp_FMA) {
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr uint32_t max_vec_bits = 64;
|
||||
|
||||
using TA = float;
|
||||
using TB = float;
|
||||
using TC = double;
|
||||
|
||||
constexpr uint32_t m = 32;
|
||||
constexpr uint32_t n = 32;
|
||||
constexpr uint32_t k = 32;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
|
||||
using tiled_mma_t = TiledMMA<
|
||||
auto shape_mnk = make_shape(_32{}, _32{}, _32{});
|
||||
auto tiled_mma = TiledMMA<
|
||||
MMA_Atom<UniversalFMA<TC, TA, TB, TC>>,
|
||||
Layout<Shape<_16, _8, _1>>
|
||||
>;
|
||||
>{};
|
||||
|
||||
auto aload = increment_by_x<float>{1.111f};
|
||||
auto bload = convert_to<float, double> {};
|
||||
auto cload = cute::negate {};
|
||||
auto cstore = cute::negate {};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 64, TA, TB, TC>(
|
||||
aload, bload, cload, cstore);
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(
|
||||
shape_mnk, tiled_mma, aload, bload, cload, cstore);
|
||||
}
|
||||
|
||||
@ -67,7 +67,6 @@ kernel(GmemTensor gC, RmemTiler tiler, CopyPolicy policy)
|
||||
|
||||
// NOTE: only 1 thread is used; this thread produces an 8x8 block of output. The fringe will not be touched.
|
||||
//copy(rC, tCgC); // Enable auto-vectorization if static
|
||||
//copy_vec<T>(rC, tCgC); // Disable auto-vectorization always
|
||||
copy(policy, rC, tCgC); // Use a policy to establish vectorization assumptions
|
||||
}
|
||||
|
||||
|
||||
@ -26,52 +26,30 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_simt
|
||||
cutlass_test_unit_gemm_device_tensorop_sm70
|
||||
cutlass_test_unit_gemm_device_tensorop_sm75
|
||||
cutlass_test_unit_gemm_device_tensorop_f16_sm80
|
||||
cutlass_test_unit_gemm_device_tensorop_f32_sm80
|
||||
cutlass_test_unit_gemm_device_tensorop_f32_tf32_sm80
|
||||
cutlass_test_unit_gemm_device_tensorop_f64
|
||||
cutlass_test_unit_gemm_device_tensorop_s32_sm80
|
||||
cutlass_test_unit_gemm_device_wmma
|
||||
cutlass_test_unit_gemm_device_tensorop_planar_complex
|
||||
cutlass_test_unit_gemm_device_sparse_tensorop_sm80
|
||||
cutlass_test_unit_gemv_device
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90
|
||||
cutlass_test_unit_sparse_gemm_device_tensorop_sm90
|
||||
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
|
||||
)
|
||||
add_custom_target(cutlass_test_unit_gemm_device)
|
||||
add_custom_target(test_unit_gemm_device)
|
||||
|
||||
add_custom_target(
|
||||
test_unit_gemm_device
|
||||
DEPENDS
|
||||
test_unit_gemm_device_simt
|
||||
test_unit_gemm_device_tensorop_sm70
|
||||
test_unit_gemm_device_tensorop_sm75
|
||||
test_unit_gemm_device_tensorop_f16_sm80
|
||||
test_unit_gemm_device_tensorop_f32_sm80
|
||||
test_unit_gemm_device_tensorop_f32_tf32_sm80
|
||||
test_unit_gemm_device_tensorop_f64
|
||||
test_unit_gemm_device_tensorop_s32_sm80
|
||||
test_unit_gemm_device_wmma
|
||||
test_unit_gemm_device_tensorop_planar_complex
|
||||
test_unit_gemm_device_sparse_tensorop_sm80
|
||||
test_unit_gemv_device
|
||||
test_unit_gemm_device_tensorop_sm90
|
||||
)
|
||||
################################################################################
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device_sm90
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90
|
||||
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
|
||||
)
|
||||
function(cutlass_test_unit_gemm_device_add_deps NAME)
|
||||
string(REGEX REPLACE "^cutlass_" "" TEST_NAME "${NAME}")
|
||||
add_dependencies(cutlass_test_unit_gemm_device ${NAME})
|
||||
add_dependencies(test_unit_gemm_device ${TEST_NAME})
|
||||
endfunction()
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
function(cutlass_test_unit_gemm_device_add_executable NAME)
|
||||
cutlass_test_unit_add_executable(${NAME} ${ARGN} DO_NOT_LOWERCASE_TEST_NAME)
|
||||
cutlass_test_unit_gemm_device_add_deps(${NAME})
|
||||
endfunction()
|
||||
|
||||
function(cutlass_test_unit_gemm_device_add_executable_split_file NAME)
|
||||
cutlass_test_unit_add_executable_split_file(${NAME} ${ARGN} DO_NOT_LOWERCASE_TEST_NAME)
|
||||
cutlass_test_unit_gemm_device_add_deps(${NAME})
|
||||
endfunction()
|
||||
|
||||
################################################################################
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_simt
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -126,7 +104,9 @@ cutlass_test_unit_add_executable(
|
||||
gemm_splitk_simt_sm50.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
list(APPEND CUTLASS_TEST_UNIT_GEMM_DEVICE_LIST cutlass_test_unit_gemm_device_simt)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_simt_3x
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -139,8 +119,7 @@ cutlass_test_unit_add_executable(
|
||||
sm61_gemm_s8_s8_s32_simt.cu
|
||||
)
|
||||
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm70
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -159,7 +138,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_splitk_tensor_op_sm70.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm75
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -204,7 +183,7 @@ cutlass_test_unit_add_executable(
|
||||
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_f16_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -214,7 +193,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_f32_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -236,7 +215,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_f32_sm80_3x
|
||||
|
||||
sm80_gemm_s8_s8_s32_tensor_op.cu
|
||||
@ -245,7 +224,7 @@ cutlass_test_unit_add_executable(
|
||||
)
|
||||
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -286,13 +265,14 @@ cutlass_test_unit_add_executable(
|
||||
gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
sm90_gemm_f16_f16_f16_tensor_op.cu
|
||||
sm90_gett_f16_f16_f16_tensor_op.cu
|
||||
sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu
|
||||
sm90_gemm_s8_s8_s8_tensor_op_s32.cu
|
||||
sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu
|
||||
@ -302,7 +282,7 @@ cutlass_test_unit_add_executable(
|
||||
sm90_gemm_f8_f8_f8_tensor_op_fp32.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90_stream_k
|
||||
|
||||
sm90_gemm_stream_k_scheduler.cu
|
||||
@ -311,7 +291,7 @@ cutlass_test_unit_add_executable(
|
||||
)
|
||||
|
||||
# Alignment tests
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_alignx_sm90
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -336,14 +316,14 @@ cutlass_test_unit_add_executable(
|
||||
)
|
||||
|
||||
# Ptr Array test
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu
|
||||
)
|
||||
|
||||
# Group Gemm test
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu
|
||||
@ -351,31 +331,25 @@ cutlass_test_unit_add_executable(
|
||||
|
||||
# Sparse tests
|
||||
# Sparse kernels trigger an ICE in gcc 7.5
|
||||
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_sparse_gemm_device_tensorop_sm90
|
||||
if (NOT (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_sparse_gemm_device_tensorop_sm90
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu
|
||||
sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu
|
||||
sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu
|
||||
sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu
|
||||
)
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu
|
||||
sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu
|
||||
sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu
|
||||
sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu
|
||||
)
|
||||
else()
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_sparse_gemm_device_tensorop_sm90
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
)
|
||||
endif()
|
||||
|
||||
# Fused epilogue tests
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_sm90
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -400,7 +374,7 @@ cutlass_test_unit_add_executable(
|
||||
sm90_gemm_f8_f8_bf16_tensor_op_fp32_evt.cu
|
||||
sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative_evt.cu
|
||||
)
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -412,7 +386,7 @@ cutlass_test_unit_add_executable(
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu
|
||||
sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu
|
||||
)
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -423,7 +397,7 @@ cutlass_test_unit_add_executable(
|
||||
sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_f32_tf32_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -443,7 +417,7 @@ cutlass_test_unit_add_executable(
|
||||
sm80_gemm_f16_f16_f32_tensor_op_f32.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_f64
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -471,7 +445,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_s32_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -493,7 +467,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_wmma
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -551,7 +525,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_planar_complex
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -562,7 +536,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm89
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -574,7 +548,7 @@ cutlass_test_unit_add_executable(
|
||||
# gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_grouped
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -583,7 +557,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_grouped_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_grouped_scheduler
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -592,7 +566,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_grouped_scheduler_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_grouped_rank_2k_scheduler
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -601,7 +575,7 @@ cutlass_test_unit_add_executable(
|
||||
rank_2k_grouped_scheduler_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_sparse_tensorop_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -622,7 +596,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemv_device
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -631,42 +605,22 @@ cutlass_test_unit_add_executable(
|
||||
gemv.cu
|
||||
)
|
||||
|
||||
if (NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if (CUTLASS_NVCC_DEVICE_COMPILE)
|
||||
|
||||
add_dependencies(
|
||||
cutlass_test_unit_gemm_device
|
||||
cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop
|
||||
|
||||
gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu
|
||||
gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu
|
||||
|
||||
gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu
|
||||
)
|
||||
|
||||
add_dependencies(
|
||||
test_unit_gemm_device
|
||||
test_unit_gemm_device_gemm_with_fused_epilogue_tensorop
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop
|
||||
|
||||
gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu
|
||||
gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu
|
||||
|
||||
gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if (CUTLASS_NVCC_DEVICE_COMPILE)
|
||||
|
||||
add_dependencies(
|
||||
cutlass_test_unit_gemm_device
|
||||
cutlass_test_unit_gemm_device_blas3
|
||||
)
|
||||
|
||||
add_dependencies(
|
||||
test_unit_gemm_device
|
||||
test_unit_gemm_device_blas3
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_blas3
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -833,7 +787,7 @@ cutlass_test_unit_add_executable(
|
||||
hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_grouped_blas3
|
||||
|
||||
BATCH_SOURCES ON
|
||||
@ -858,13 +812,12 @@ cutlass_test_unit_add_executable(
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if (CUTLASS_NVCC_DEVICE_COMPILE)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_broadcast
|
||||
|
||||
gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu
|
||||
)
|
||||
cutlass_test_unit_gemm_device_add_executable(
|
||||
cutlass_test_unit_gemm_device_broadcast
|
||||
gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <numeric> // std::lcm
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
@ -55,6 +56,7 @@
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
|
||||
#include "testbed_utils.h"
|
||||
|
||||
@ -151,6 +153,12 @@ struct ElementScalarType<Gemm, Default, std::void_t<typename Gemm::EpilogueOutpu
|
||||
using Type = typename Gemm::EpilogueOutputOp::ElementScalar;
|
||||
};
|
||||
|
||||
template<class CollectiveEpilogue, class = void>
|
||||
struct IsSfdEpi : cute::false_type {};
|
||||
|
||||
template<class CollectiveEpilogue>
|
||||
struct IsSfdEpi<CollectiveEpilogue, cute::void_t<typename CollectiveEpilogue::FusionCallbacks::Operation::GmemLayoutTagScalefactor>> : cute::true_type {};
|
||||
|
||||
// The maximum swizzle size to use
|
||||
//
|
||||
// This class, like Splits above, makes it harder to confuse
|
||||
@ -1140,7 +1148,6 @@ struct HostCollectiveEpilogue {
|
||||
static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported &&
|
||||
(cute::is_same_v<ElementAux, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>);
|
||||
|
||||
using Arguments = typename Gemm::GemmKernel::EpilogueArguments;
|
||||
|
||||
/// Initialization
|
||||
@ -1454,6 +1461,22 @@ struct HostCollectiveEpilogue {
|
||||
|
||||
bool passed = equality_check(reference_D.host_view(), tensor_D.host_view());
|
||||
if(!passed) {
|
||||
#if 0
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_d));
|
||||
auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_d));
|
||||
for(int i=0; i<M; i++) {
|
||||
for(int j=0; j<N; j++) {
|
||||
for(int l=0; l<L; l++) {
|
||||
if(static_cast<float>(ElementD(ref(i, j, l))) != static_cast<float>((ElementD(comp(i, j, l))))) {
|
||||
printf("<m %d, n %d, l %d> ref: %f comp: %f\n", i, j, l, static_cast<float>(ElementD(ref(i, j, l))), static_cast<float>((ElementD(comp(i, j, l)))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
std::cout<<"D is incorrect"<<std::endl;
|
||||
}
|
||||
|
||||
@ -1575,9 +1598,12 @@ struct HostCollectiveEpilogue {
|
||||
}
|
||||
else {
|
||||
fusion_args.alpha = alpha.at(coord_0);
|
||||
fusion_args.beta = beta.at(coord_0);
|
||||
fusion_args.alpha_ptr = alpha.device_data();
|
||||
fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr
|
||||
// Only initializing beta/beta_ptr for non-void source
|
||||
if constexpr (not cute::is_void_v<typename kernel::ElementC>) {
|
||||
fusion_args.beta = beta.at(coord_0);
|
||||
fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr
|
||||
}
|
||||
|
||||
if constexpr (IsPerRowScaleEnabled) {
|
||||
int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0;
|
||||
@ -1620,6 +1646,7 @@ struct HostCollectiveEpilogue {
|
||||
// example of how to set kernel activation arguments
|
||||
// see ActivationFunctor::Arguments in activation.h for definition
|
||||
// if Arguments doesn't exist then fusion_args.activation is empty
|
||||
|
||||
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU_taylor<ElementCompute>>) {
|
||||
fusion_args.activation.scale = ElementCompute(1);
|
||||
}
|
||||
@ -1713,6 +1740,7 @@ struct HostCollectiveEpilogue {
|
||||
decltype(Vbeta),
|
||||
ActivationFunctor,
|
||||
cutlass::plus<ElementCompute>
|
||||
, false /*PerColumnBias_*/
|
||||
> epilogue_params{};
|
||||
|
||||
epilogue_params.C = C;
|
||||
@ -1779,6 +1807,7 @@ struct TestbedImpl {
|
||||
using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule;
|
||||
// All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type
|
||||
using HostCollectiveMainloopType = HostCollectiveMainloop<ScheduleType, Gemm, ElementA, ElementB>;
|
||||
|
||||
using CollectiveEpilogue = cute::conditional_t<IsDefaultEpilogue<typename Gemm::GemmKernel::CollectiveEpilogue>::value || force_legacy_epilogue,
|
||||
HostCollectiveDefaultEpilogue<Gemm>,
|
||||
HostCollectiveEpilogue<Gemm>>;
|
||||
@ -2004,7 +2033,7 @@ struct TestbedImpl {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
catch (std::exception const& e) {
|
||||
catch ([[maybe_unused]] std::exception const& e) {
|
||||
CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what());
|
||||
throw;
|
||||
}
|
||||
|
||||
@ -346,7 +346,7 @@ struct HostCollectiveMainloop {
|
||||
stride_b_host.clear();
|
||||
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = cutlass::platform::max(problem_shapes.groups(), L);
|
||||
|
||||
for(int32_t i = 0; i < L; ++i) {
|
||||
auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1);
|
||||
@ -380,7 +380,7 @@ struct HostCollectiveMainloop {
|
||||
|
||||
Arguments to_args(ProblemShapeType problem_shapes) {
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = cutlass::platform::max(problem_shapes.groups(), L);
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(L);
|
||||
std::vector<ElementB *> ptr_B_host(L);
|
||||
@ -587,7 +587,7 @@ struct HostCollectiveDefaultEpilogue {
|
||||
stride_d_host.clear();
|
||||
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = cutlass::platform::max(problem_shapes.groups(), L);
|
||||
|
||||
for (int32_t i = 0; i < L; ++i) {
|
||||
auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1);
|
||||
@ -649,7 +649,7 @@ struct HostCollectiveDefaultEpilogue {
|
||||
ElementScalar beta,
|
||||
int batch) {
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = cutlass::platform::max(problem_shapes.groups(), L);
|
||||
|
||||
tensors_D[batch].sync_host();
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0);
|
||||
@ -678,7 +678,7 @@ struct HostCollectiveDefaultEpilogue {
|
||||
|
||||
Arguments to_args(ProblemShapeType problem_shapes) {
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = cutlass::platform::max(problem_shapes.groups(), L);
|
||||
|
||||
std::vector<ElementC *> ptr_C_host(L);
|
||||
std::vector<ElementD *> ptr_D_host(L);
|
||||
@ -724,8 +724,8 @@ struct HostCollectiveDefaultEpilogue {
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1);
|
||||
L = std::max(problem_shapes.groups(), L);
|
||||
|
||||
auto coord_0 = cutlass::make_Coord(0);
|
||||
auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()),
|
||||
@ -905,9 +905,8 @@ struct HostCollectiveEpilogue {
|
||||
references_D.clear();
|
||||
stride_c_host.clear();
|
||||
stride_d_host.clear();
|
||||
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = std::max(problem_shapes.groups(), L);
|
||||
|
||||
for (int32_t i = 0; i < L; ++i) {
|
||||
auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1);
|
||||
@ -1118,7 +1117,6 @@ struct HostCollectiveEpilogue {
|
||||
passed &= tmp;
|
||||
}
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
@ -1189,7 +1187,7 @@ struct HostCollectiveEpilogue {
|
||||
Arguments to_args(ProblemShapeType problem_shapes) {
|
||||
auto coord_0 = cutlass::make_Coord(0);
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = std::max(problem_shapes.groups(), L);
|
||||
|
||||
std::vector<ElementC *> ptr_C_host(L);
|
||||
std::vector<ElementD *> ptr_D_host(L);
|
||||
@ -1220,19 +1218,22 @@ struct HostCollectiveEpilogue {
|
||||
device_tensors_Aux.copy_from_host(ptr_Aux_host.data());
|
||||
}
|
||||
|
||||
auto device_tensors_C_ptr = cute::is_void_v<typename kernel::ElementC> ? nullptr :
|
||||
reinterpret_cast<typename kernel::ElementC const**>(device_tensors_C.get());
|
||||
|
||||
Arguments arguments;
|
||||
if constexpr (IsGroupGemm) {
|
||||
arguments =
|
||||
{
|
||||
{},
|
||||
device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get()
|
||||
device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get()
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments =
|
||||
{
|
||||
{},
|
||||
device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0]
|
||||
device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0]
|
||||
};
|
||||
}
|
||||
|
||||
@ -1252,7 +1253,9 @@ struct HostCollectiveEpilogue {
|
||||
fusion_args.beta = beta.at(coord_0);
|
||||
|
||||
fusion_args.alpha_ptr = alpha.device_data();
|
||||
fusion_args.beta_ptr = beta.device_data();
|
||||
// can_implement requires beta_ptr to not be set when ElementC is void
|
||||
fusion_args.beta_ptr = cute::is_void_v<typename kernel::ElementC> ? nullptr :
|
||||
beta.device_data();
|
||||
|
||||
if constexpr (IsScaleFactorEnabled) {
|
||||
fusion_args.scale_a = scale_A.at(coord_0);
|
||||
@ -1316,7 +1319,8 @@ struct HostCollectiveEpilogue {
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1);
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
auto coord_0 = cutlass::make_Coord(0);
|
||||
auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch]));
|
||||
@ -1338,7 +1342,6 @@ struct HostCollectiveEpilogue {
|
||||
cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M)));
|
||||
auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N)));
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
@ -1518,7 +1521,7 @@ struct TestbedImpl {
|
||||
{
|
||||
using namespace cute;
|
||||
auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1);
|
||||
L = max(problem_shapes.groups(), L);
|
||||
L = std::max(problem_shapes.groups(), L);
|
||||
|
||||
bool passed = true;
|
||||
for (int32_t i = 0; i < L; ++i) {
|
||||
@ -1760,7 +1763,7 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative
|
||||
cutlass::DeviceAllocation<typename ProblemShapeType::UnderlyingProblemShape> problem_sizes_device;
|
||||
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)});
|
||||
}
|
||||
|
||||
problem_sizes_device.reset(problem_sizes_host.size());
|
||||
|
||||
184
test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu
Normal file
184
test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu
Normal file
@ -0,0 +1,184 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/util/reference/device/gett.hpp"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Runs a half-precision GETT through cutlass::gemm::device::GemmUniversalAdapter
// where the row (M), column (N) and reduction (K) modes are each rank-2, then
// checks the device result element-for-element against
// cutlass::reference::device::gett.
TEST(SM90_Device_Gett_f16t_f16n_f16n_tensor_op_gmma_f16, 8x8x8x8x8x8) {
|
||||
|
||||
// Stride types. Each of the M/N/K modes carries two sub-mode strides; the batch
// mode carries a single dynamic stride.
using BatModeStrides = int;
|
||||
|
||||
using RowModeStridesA = cute::Stride<int, int>;
|
||||
// Reduction-mode strides are shared by A and B; the first reduction sub-mode has
// a static unit stride.
using RedModeStrides = cute::Stride<cute::_1, int>;
|
||||
|
||||
using ColModeStridesB = cute::Stride<int, int>;
|
||||
|
||||
// For C/D the first row sub-mode has a static unit stride.
using RowModeStridesC = cute::Stride<cute::_1, int>;
|
||||
using ColModeStridesC = cute::Stride<int, int>;
|
||||
|
||||
// Full operand strides are ordered (row-or-column mode, second mode, batch).
using StrideA = cute::Stride<RowModeStridesA, RedModeStrides, BatModeStrides>;
|
||||
using StrideB = cute::Stride<ColModeStridesB, RedModeStrides, BatModeStrides>;
|
||||
using StrideC = cute::Stride<RowModeStridesC, ColModeStridesC, BatModeStrides>;
|
||||
using StrideD = StrideC;
|
||||
|
||||
// Every tile mode is a pair of static extents (8, 8); the test name's
// 8x8x8x8x8x8 suffix lists these six extents.
using TileShape = Shape<Shape<_8, _8>, Shape<_8, _8>, Shape<_8, _8>>;
|
||||
|
||||
// Mainloop collective selected by the builder for Sm90 tensor ops, with the
// stage count and kernel schedule left on their Auto settings.
// NOTE(review): the literal 8 after each stride is presumably the operand
// alignment in elements - confirm against the CollectiveBuilder parameter list.
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, StrideA, 8,
|
||||
cutlass::half_t, StrideB, 8,
|
||||
cutlass::half_t,
|
||||
TileShape, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Epilogue collective selected by the builder with the epilogue tile and
// schedule left on their Auto settings. StrideC is passed for both the C and D
// operands (StrideD is an alias of StrideC).
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
cutlass::half_t, cutlass::half_t,
|
||||
cutlass::half_t, StrideC, 8,
|
||||
cutlass::half_t, StrideC, 8,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Kernel type: the problem shape is ((int,int), (int,int), (int,int), int),
// i.e. three rank-2 dynamic modes plus one dynamic batch extent.
using GettKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<Shape<int,int>,
|
||||
Shape<int,int>,
|
||||
Shape<int,int>,
|
||||
int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gett = cutlass::gemm::device::GemmUniversalAdapter<GettKernel>;
|
||||
|
||||
// Problem extents: M = (32, 8), N = (32, 4), K = (32, 2), batch count L = 1.
auto problem_shape = make_shape(
|
||||
make_shape(32,8),
|
||||
make_shape(32,4),
|
||||
make_shape(32,2),
|
||||
1
|
||||
);
|
||||
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
|
||||
// Runtime strides for the buffers allocated below. The batch stride of each
// operand equals that operand's per-batch element count.
StrideA dA = make_stride(make_stride(64, 2048), make_stride(_1{}, 32), size(M) * size(K));
|
||||
StrideB dB = make_stride(make_stride(64, 2048), make_stride(_1{}, 32), size(N) * size(K));
|
||||
StrideC dC = make_stride(make_stride(_1{}, 32), make_stride(256, 8192), size(M) * size(N));
|
||||
StrideD dD = dC;
|
||||
|
||||
// Epilogue scalars, both set to 1.
cutlass::half_t alpha = cutlass::half_t(1.0f);
|
||||
cutlass::half_t beta = cutlass::half_t(1.0f);
|
||||
|
||||
// Host buffers sized as the product of the relevant mode sizes and the batch count.
thrust::host_vector<cutlass::half_t> A_h(size(M) * size(K) * size(L));
|
||||
thrust::host_vector<cutlass::half_t> B_h(size(N) * size(K) * size(L));
|
||||
thrust::host_vector<cutlass::half_t> C_h(size(M) * size(N) * size(L));
|
||||
thrust::host_vector<cutlass::half_t> D_h(size(M) * size(N) * size(L));
|
||||
thrust::host_vector<cutlass::half_t> D_h_ref(size(M) * size(N) * size(L));
|
||||
|
||||
// Inputs are filled with random values truncated to integers before conversion
// to half_t.
// NOTE(review): presumably integers are used so the half-precision result can
// be compared for exact equality - confirm.
for (auto& a : A_h) a = cutlass::half_t(static_cast<int>(4 * (rand() / double(RAND_MAX) - 1)));
|
||||
for (auto& b : B_h) b = cutlass::half_t(static_cast<int>(4 * (rand() / double(RAND_MAX) - 1)));
|
||||
for (auto& c : C_h) c = cutlass::half_t(static_cast<int>(4 * (rand() / double(RAND_MAX) - 1)));
|
||||
// Both output buffers start out filled with -1.
for (auto& d : D_h) d = cutlass::half_t(-1);
|
||||
for (auto& d : D_h_ref) d = cutlass::half_t(-1);
|
||||
|
||||
// Device vectors initialized from the host vectors.
thrust::device_vector<cutlass::half_t> A = A_h;
|
||||
thrust::device_vector<cutlass::half_t> B = B_h;
|
||||
thrust::device_vector<cutlass::half_t> C = C_h;
|
||||
thrust::device_vector<cutlass::half_t> D = D_h;
|
||||
thrust::device_vector<cutlass::half_t> D_ref = D_h_ref;
|
||||
|
||||
// Kernel arguments: batched mode, the problem shape, the A/B pointers with
// their strides, and the epilogue scalars with the C/D pointers and strides.
typename Gett::Arguments args {
|
||||
cutlass::gemm::GemmUniversalMode::kBatched,
|
||||
problem_shape,
|
||||
{A.data().get(), dA, B.data().get(), dB},
|
||||
{ {alpha, beta}, C.data().get(), dC, D.data().get(), dD}
|
||||
};
|
||||
|
||||
// Run the kernel under test, then synchronize and check both the CUTLASS
// status and the CUDA error code.
Gett gett;
|
||||
auto status = gett(args);
|
||||
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
||||
auto cuda_err = cudaDeviceSynchronize();
|
||||
|
||||
EXPECT_TRUE(cuda_err == cudaSuccess);
|
||||
|
||||
// Reference computation written to D_ref using the same operands, strides,
// alpha and beta.
// NOTE(review): the half_t(0.0f) argument is presumably the accumulator's
// initial value - confirm against cutlass/util/reference/device/gett.hpp.
cutlass::reference::device::gett(
|
||||
problem_shape,
|
||||
A.data().get(), dA,
|
||||
B.data().get(), dB,
|
||||
cutlass::half_t(0.0f),
|
||||
C.data().get(), dC,
|
||||
D_ref.data().get(), dD,
|
||||
alpha, beta);
|
||||
|
||||
cuda_err = cudaDeviceSynchronize();
|
||||
EXPECT_TRUE(cuda_err == cudaSuccess);
|
||||
|
||||
// Compare the kernel output against the reference over all D_ref.size() elements.
bool passed = cutlass::reference::device::BlockCompareEqual(
|
||||
D.data().get(), D_ref.data().get(), D_ref.size());
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -33,7 +33,7 @@ add_custom_target(
|
||||
cutlass_test_unit_transform
|
||||
DEPENDS
|
||||
cutlass_test_unit_transform_threadblock
|
||||
cutlass_test_unit_transform_filter_format
|
||||
cutlass_test_unit_transform_kernel
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
|
||||
@ -45,6 +45,7 @@
|
||||
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor
|
||||
#include "cutlass/arch/arch.h" // cutlass::arch::Sm90
|
||||
#include "cutlass/cutlass.h" // cutlass::Status
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
#include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t
|
||||
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
|
||||
#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo
|
||||
@ -219,9 +220,7 @@ public:
|
||||
// * EltA
|
||||
using ElementA = ElementA_;
|
||||
using ElementAUint = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
|
||||
static constexpr bool IsRuntimeDataTypeA = cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float8_t> ||
|
||||
cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float6_t> ||
|
||||
cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float4_t>;
|
||||
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
|
||||
using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA,
|
||||
cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>,
|
||||
ElementA>;
|
||||
|
||||
@ -60,6 +60,7 @@
|
||||
#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride
|
||||
#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals
|
||||
#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
|
||||
#include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor
|
||||
#include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE
|
||||
|
||||
@ -27,6 +27,6 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_transform_filter_format
|
||||
cutlass_test_unit_transform_kernel
|
||||
filter_format_transformer.cu
|
||||
)
|
||||
|
||||
@ -104,7 +104,7 @@ void run_test(int M, int N) {
|
||||
for (int n = 0; n < N; ++n) {
|
||||
auto diff = abs(static_cast<float>(output_ref.at({m, n}) - output.at({m, n})));
|
||||
mean_abs_diff += diff;
|
||||
max_abs_diff = max(max_abs_diff, diff);
|
||||
max_abs_diff = cutlass::platform::max(max_abs_diff, diff);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user