3.6.0 update (#2005)
* 3.6.0 update * doc and swap stuff --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -80,4 +80,3 @@ foreach(FUSION_GEMM_EXAMPLE
|
||||
add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE})
|
||||
|
||||
endforeach()
|
||||
|
||||
|
||||
@ -59,11 +59,11 @@
|
||||
// Also, we don't check the index value is legal and index array point is valid
|
||||
// for the sake of the performance.
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
#include <cstdlib>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@ -33,11 +33,7 @@
|
||||
computing reference permutations of 4/5D tensors when source data is column-major.
|
||||
*/
|
||||
#pragma once
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include "assert.h"
|
||||
#endif
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
@ -30,8 +30,8 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
#include <float.h>
|
||||
#include <stdio.h>
|
||||
#include <cfloat>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -43,11 +43,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
@ -57,12 +53,9 @@
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@ -43,11 +43,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
@ -57,16 +53,12 @@
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/scale_type.h"
|
||||
|
||||
@ -550,7 +550,7 @@ public:
|
||||
|
||||
auto prologueV = [&](int blockN) {
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
|
||||
typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
|
||||
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
@ -719,7 +719,7 @@ public:
|
||||
}
|
||||
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
|
||||
typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
|
||||
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
@ -761,15 +761,15 @@ public:
|
||||
using EpilogueOutputOp = typename cutlass::epilogue::
|
||||
thread::MemoryEfficientAttentionNormalize<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
output_t,
|
||||
output_accum_t>::type,
|
||||
output_accum_t,
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator,
|
||||
output_accum_t,
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
kIsFirst::value,
|
||||
kIsLast::value,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
EpiloguePipelined<
|
||||
@ -777,7 +777,7 @@ public:
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
typename MM1::OutputTileIterator,
|
||||
typename MM1::OutputTileIteratorAccum>::type,
|
||||
typename DefaultEpilogue::
|
||||
@ -795,7 +795,7 @@ public:
|
||||
int col = blockN * MM1::Mma::Shape::kN;
|
||||
auto source_iter = createOutputAccumIter(col);
|
||||
auto dest_iter = gemm_kernel_utils::call_conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
@ -817,8 +817,8 @@ public:
|
||||
}
|
||||
|
||||
if (kKeepOutputInRF) {
|
||||
const bool kIsFirst = true;
|
||||
const bool kIsLast = true;
|
||||
constexpr bool kIsFirst = true;
|
||||
constexpr bool kIsLast = true;
|
||||
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
||||
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
|
||||
using ElementCompute = typename DefaultOp::ElementCompute;
|
||||
|
||||
@ -55,13 +55,14 @@
|
||||
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
|
||||
{ \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
using BOOL_NAME = std::true_type; \
|
||||
F(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
using BOOL_NAME = std::false_type; \
|
||||
F(); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_ARCHTAG(CC, func) \
|
||||
{ \
|
||||
if (CC >= 80) { \
|
||||
|
||||
@ -32,6 +32,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cinttypes>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
@ -85,8 +86,6 @@
|
||||
#include "gemm/mma_from_smem.h"
|
||||
#include "transform/tile_smem_loader.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
using namespace gemm_kernel_utils;
|
||||
|
||||
namespace {
|
||||
@ -1956,7 +1955,8 @@ struct AttentionBackwardKernel {
|
||||
|
||||
// no-op epilogue operator - just casting and storing contents of
|
||||
// accum to global memory
|
||||
typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1});
|
||||
typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op(
|
||||
typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1});
|
||||
typename MatmulDOIVJ::BiasGradEpilogue epilogue(
|
||||
shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id);
|
||||
epilogue(output_op, output_iter, accum, output_iter);
|
||||
@ -2211,7 +2211,7 @@ struct AttentionBackwardKernel {
|
||||
incrIteration(p, query_start, key_start, next_query, next_key);
|
||||
DISPATCH_BOOL(
|
||||
next_key != key_start, kForceReloadK, ([&]() {
|
||||
prologueQkNextIteration<kForceReloadK>(
|
||||
prologueQkNextIteration<kForceReloadK::value>(
|
||||
shared_storage, p, next_query, next_key, warp_id, lane_id);
|
||||
}));
|
||||
}
|
||||
@ -2342,7 +2342,7 @@ struct AttentionBackwardKernel {
|
||||
thread_id,
|
||||
cutlass::MatrixCoord{0, 0});
|
||||
|
||||
MatmulQK::Mma::prologue<kReloadK, true>(
|
||||
MatmulQK::Mma::template prologue<kReloadK, true>(
|
||||
shared_storage.mm_qk_k(),
|
||||
shared_storage.mm_qk_q(),
|
||||
iterator_A,
|
||||
@ -2369,6 +2369,7 @@ struct AttentionBackwardKernel {
|
||||
p.grad_value_ptr + key_start * p.gV_strideM(),
|
||||
{num_keys_in_block, p.head_dim_value},
|
||||
thread_id);
|
||||
|
||||
accumulateInGmem<MatmulGradV>(
|
||||
shared_storage.gradV_epilogue_final(),
|
||||
output_frags.gradV,
|
||||
@ -2406,7 +2407,7 @@ struct AttentionBackwardKernel {
|
||||
int thread_id = 32 * warp_id + lane_id;
|
||||
DISPATCH_BOOL(
|
||||
first, kIsFirst, ([&]() {
|
||||
static constexpr auto ScaleType = kIsFirst
|
||||
static constexpr auto ScaleType = kIsFirst::value
|
||||
? cutlass::epilogue::thread::ScaleType::Nothing
|
||||
: cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
using EpilogueOutputOp =
|
||||
|
||||
@ -38,6 +38,7 @@
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include <cmath>
|
||||
#include <cinttypes>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/fast_math.h"
|
||||
@ -71,8 +72,6 @@
|
||||
#include "gemm_kernel_utils.h"
|
||||
#include "transform/tile_smem_loader.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
using namespace gemm_kernel_utils;
|
||||
|
||||
namespace {
|
||||
@ -1036,15 +1035,15 @@ struct AttentionKernel {
|
||||
using EpilogueOutputOp = typename cutlass::epilogue::
|
||||
thread::MemoryEfficientAttentionNormalize<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
output_t,
|
||||
output_accum_t>::type,
|
||||
output_accum_t,
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator,
|
||||
ElementCompute,
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
kIsFirst::value,
|
||||
kIsLast::value,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
EpiloguePipelined<
|
||||
@ -1052,7 +1051,7 @@ struct AttentionKernel {
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
typename MM1::OutputTileIterator,
|
||||
typename MM1::OutputTileIteratorAccum>::type,
|
||||
typename DefaultEpilogue::
|
||||
@ -1070,7 +1069,7 @@ struct AttentionKernel {
|
||||
int col = blockN * MM1::Mma::Shape::kN;
|
||||
auto source_iter = createOutputAccumIter(col);
|
||||
auto dest_iter = call_conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
|
||||
@ -39,11 +39,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@ -53,12 +49,9 @@
|
||||
#include "cutlass/tensor_coord.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ class gen_test:
|
||||
|
||||
def gen_cpp_sample(self):
|
||||
code = "/* Auto Generated code - Do not edit.*/\n"
|
||||
code += "#include <stdio.h> \n"
|
||||
code += "#include <cstdio> \n"
|
||||
|
||||
code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n"
|
||||
code += "#include \"cutlass/cutlass.h\" \n"
|
||||
|
||||
@ -380,7 +380,7 @@ class gen_one_API:
|
||||
def gen_CUTLASS_irrelevant_API(self):
|
||||
code = ""
|
||||
code += "#include <cuda_runtime.h>\n"
|
||||
code += "#include <assert.h>\n"
|
||||
code += "#include <cassert>\n"
|
||||
|
||||
param_name = "Fused" + str(self.b2b_num) + "xGemm_"
|
||||
for i in range(self.b2b_num):
|
||||
|
||||
@ -66,7 +66,7 @@ int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string &
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == arch_major && props.minor == arch_minor)) {
|
||||
if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) {
|
||||
supported = false;
|
||||
}
|
||||
|
||||
|
||||
@ -38,11 +38,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@ -45,18 +45,18 @@
|
||||
and BEFORE scatter operations are applied.
|
||||
*/
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <numeric>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
@ -64,7 +64,6 @@
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
@ -619,7 +619,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
@ -681,4 +680,4 @@ int main(int argc, char const **args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -559,4 +559,4 @@ int main(int argc, char const **args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -9,12 +9,17 @@ This first version only supports mixed type GEMMs using TMA.
|
||||
|
||||
## Performance
|
||||
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type.
|
||||
|
||||
The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array<ElementScale, 8>` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now.
|
||||
|
||||
|
||||
Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details.
|
||||
|
||||
|
||||
We are currently optimizing the following cases:
|
||||
1. Memory bound cases for all types
|
||||
2. `fp8 x {int2, uint2}` case
|
||||
|
||||
## Limitations
|
||||
|
||||
@ -36,4 +41,4 @@ We are currently optimizing the following cases:
|
||||
|
||||
* Optimizations for memory bound cases.
|
||||
|
||||
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
|
||||
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
|
||||
|
||||
@ -151,16 +151,16 @@ void mixed_dtype_profiling(
|
||||
runtimes.reserve(options.iterations);
|
||||
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
}
|
||||
|
||||
cudaEventDestroy(start);
|
||||
|
||||
@ -33,6 +33,9 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
@ -197,7 +200,6 @@ bool initialize_packed_scale(
|
||||
{
|
||||
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
|
||||
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
|
||||
// std::cout << data_in[i] << ":" << std::hex << static_cast<uint16_t>(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast<uint16_t>((-data_in[i]).storage) << std::endl;
|
||||
}
|
||||
try {
|
||||
block_out.copy_from_host(data_out.data());
|
||||
@ -207,4 +209,4 @@ bool initialize_packed_scale(
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -159,4 +159,4 @@ void reorder_tensor(
|
||||
cutlass::DeviceAllocation<T> temp(size(layout_src));
|
||||
reorder_tensor(data, layout_src, temp.get(), layout_dst);
|
||||
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,7 +63,7 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <float.h>
|
||||
#include <cfloat>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
|
||||
@ -35,9 +35,36 @@
|
||||
|
||||
#include "dispatch_policy_extra.hpp"
|
||||
#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
|
||||
#include "../pipeline/prefetch_pipeline_sm90.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int stages>
|
||||
constexpr int
|
||||
compute_stage_count_or_override_prefetch(StageCount<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes>
|
||||
constexpr int
|
||||
compute_stage_count_or_override_prefetch(StageCountAutoCarveout<carveout_bytes> stage_count) {
|
||||
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
|
||||
constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>);
|
||||
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
|
||||
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
|
||||
constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size
|
||||
constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}));
|
||||
constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast<int>(mainloop_pipeline_bytes);
|
||||
|
||||
return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
|
||||
template <
|
||||
class ElementA,
|
||||
@ -98,7 +125,7 @@ struct CollectiveBuilder<
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
@ -184,7 +211,7 @@ struct CollectiveBuilder<
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
|
||||
@ -57,6 +57,19 @@ using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
constexpr int PrefetchStages = 4;
|
||||
constexpr int PrefetchInitialStages = 1;
|
||||
// This determines how much shmem we set aside for prefetch.
|
||||
// We don't reuse anything loaded by prefetcher, so we can keep
|
||||
// loading into the same place -- there will be a conflict when
|
||||
// writing, but it doesn't affect performance as much as the doors
|
||||
// that this opens.
|
||||
constexpr int PrefetchStagesActual = 1;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
@ -117,15 +130,7 @@ struct CollectiveMma<
|
||||
static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1");
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
|
||||
static constexpr int PrefetchStages = 4;
|
||||
static constexpr int PrefetchInitialStages = 1;
|
||||
// This determines how much shmem we set aside for prefetch.
|
||||
// We don't reuse anything loaded by prefetcher, so we can keep
|
||||
// loading into the same place -- there will be a conflict when
|
||||
// writing, but it doesn't affect performance as much as the doors
|
||||
// that this opens.
|
||||
static constexpr int PrefetchStagesActual = 1;
|
||||
using PrefetcherPipeline = cutlass::PrefetchPipeline<PrefetchStages>;
|
||||
using PrefetcherPipeline = cutlass::PrefetchPipeline<detail::PrefetchStages>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
@ -155,7 +160,7 @@ struct CollectiveMma<
|
||||
using PrefetchSmemLayoutA = decltype(make_layout(make_shape(
|
||||
cute::Int<size<0>(SmemLayoutA{})>{},
|
||||
cute::Int<size<1>(SmemLayoutA{})>{},
|
||||
cute::Int<PrefetchStagesActual>{})));
|
||||
cute::Int<detail::PrefetchStagesActual>{})));
|
||||
|
||||
static constexpr auto prefetch_smem_size = cute::cosize_v<PrefetchSmemLayoutA>;
|
||||
|
||||
@ -176,7 +181,7 @@ struct CollectiveMma<
|
||||
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
|
||||
// Defined outside the class where it's used, to work around MSVC issues
|
||||
using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>;
|
||||
using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<detail::PrefetchStages>;
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128, _0> {
|
||||
@ -660,7 +665,7 @@ struct CollectiveMma<
|
||||
bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0;
|
||||
float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio;
|
||||
int prefetch_iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5 * prefetch_ratio);
|
||||
prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages);
|
||||
prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages);
|
||||
|
||||
Tensor sA = make_tensor(
|
||||
make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
@ -702,7 +707,7 @@ struct CollectiveMma<
|
||||
break;
|
||||
}
|
||||
|
||||
prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages);
|
||||
prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages);
|
||||
using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType;
|
||||
BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage);
|
||||
|
||||
|
||||
35
examples/64_ada_fp8_gemm_grouped/CMakeLists.txt
Normal file
35
examples/64_ada_fp8_gemm_grouped/CMakeLists.txt
Normal file
@ -0,0 +1,35 @@
|
||||
|
||||
# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
64_ada_fp8_gemm_grouped
|
||||
ada_fp8_gemm_grouped.cu
|
||||
)
|
||||
1208
examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu
Normal file
1208
examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -143,8 +143,10 @@ foreach(EXAMPLE
|
||||
61_hopper_gemm_with_topk_and_softmax
|
||||
62_hopper_sparse_gemm
|
||||
63_hopper_gemm_with_weight_prefetch
|
||||
64_ada_fp8_gemm_grouped
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
||||
endforeach()
|
||||
|
||||
|
||||
@ -95,36 +95,17 @@ __global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
|
||||
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
|
||||
/// has the precondition that pointers are aligned to the vector size.
|
||||
///
|
||||
template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
|
||||
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
|
||||
template <class TensorS, class TensorD, class Tiled_Copy>
|
||||
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, Tiled_Copy tiled_copy)
|
||||
{
|
||||
using namespace cute;
|
||||
using Element = typename TensorS::value_type;
|
||||
|
||||
// Slice the tensors to obtain a view into each tile.
|
||||
Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
|
||||
Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
|
||||
|
||||
// Define `AccessType` which controls the size of the actual memory access.
|
||||
using AccessType = cutlass::AlignedArray<Element, size(VecLayout{})>;
|
||||
|
||||
// A copy atom corresponds to one hardware memory access.
|
||||
using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
|
||||
|
||||
// Construct tiled copy, a tiling of copy atoms.
|
||||
//
|
||||
// Note, this assumes the vector and thread layouts are aligned with contigous data
|
||||
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
|
||||
// reads. Alternative vector layouts are also possible, though incompatible layouts
|
||||
// will result in compile time errors.
|
||||
auto tiled_copy =
|
||||
make_tiled_copy(
|
||||
Atom{}, // access size
|
||||
ThreadLayout{}, // thread layout
|
||||
VecLayout{}); // vector layout (e.g. 4x1)
|
||||
|
||||
// Construct a Tensor corresponding to each thread's slice.
|
||||
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
|
||||
ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
|
||||
|
||||
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CopyOp, CopyM, CopyN)
|
||||
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CopyOp, CopyM, CopyN)
|
||||
@ -198,11 +179,34 @@ int main(int argc, char** argv)
|
||||
Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); // ((M, N), m', n')
|
||||
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
|
||||
|
||||
// Thread arrangement
|
||||
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{}));
|
||||
// Construct a TiledCopy with a specific access pattern.
|
||||
// This version uses a
|
||||
// (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc),
|
||||
// (2) Layout-of-Values that each thread will access.
|
||||
|
||||
// Vector dimensions
|
||||
Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));
|
||||
// Thread arrangement
|
||||
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{})); // (32,8) -> thr_idx
|
||||
|
||||
// Value arrangement per thread
|
||||
Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx
|
||||
|
||||
// Define `AccessType` which controls the size of the actual memory access instruction.
|
||||
using CopyOp = UniversalCopy<uint_byte_t<sizeof(Element) * size(val_layout)>>; // A very specific access width copy instruction
|
||||
//using CopyOp = UniversalCopy<cutlass::AlignedArray<Element, size(val_layout)>>; // A more generic type that supports many copy strategies
|
||||
//using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs
|
||||
|
||||
// A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element.
|
||||
using Atom = Copy_Atom<CopyOp, Element>;
|
||||
|
||||
// Construct tiled copy, a tiling of copy atoms.
|
||||
//
|
||||
// Note, this assumes the vector and thread layouts are aligned with contigous data
|
||||
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
|
||||
// reads. Alternative value layouts are also possible, though incompatible layouts
|
||||
// will result in compile time errors.
|
||||
TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy
|
||||
thr_layout, // thread layout (e.g. 32x4 Col-Major)
|
||||
val_layout); // value layout (e.g. 4x1)
|
||||
|
||||
//
|
||||
// Determine grid and block dimensions
|
||||
@ -217,8 +221,7 @@ int main(int argc, char** argv)
|
||||
copy_kernel_vectorized<<< gridDim, blockDim >>>(
|
||||
tiled_tensor_S,
|
||||
tiled_tensor_D,
|
||||
thr_layout,
|
||||
vec_layout);
|
||||
tiled_copy);
|
||||
|
||||
cudaError result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
Reference in New Issue
Block a user