3.6.0 update (#2005)

* 3.6.0 update

* doc and swap stuff

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Yujia Zhai
2024-12-24 22:34:40 -08:00
committed by GitHub
parent e1cd8c7866
commit 3d261a5974
258 changed files with 10863 additions and 3883 deletions

View File

@ -80,4 +80,3 @@ foreach(FUSION_GEMM_EXAMPLE
add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE})
endforeach()

View File

@ -59,11 +59,11 @@
// Also, for performance, we don't check that index values are legal or that
// the index array pointer is valid.
-#include <stdlib.h>
-#include <stdio.h>
-#include <time.h>
-#include <math.h>
-#include <assert.h>
+#include <cstdlib>
+#include <cstdio>
+#include <ctime>
+#include <cmath>
+#include <cassert>
#include <cuda_runtime.h>
#include <algorithm>

View File

@ -33,11 +33,7 @@
computing reference permutations of 4/5D tensors when source data is column-major.
*/
#pragma once
-#if defined(__CUDACC_RTC__)
-#include <cuda/std/cassert>
-#else
-#include "assert.h"
-#endif
#include "cutlass/cutlass.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"

View File

@ -30,8 +30,8 @@
**************************************************************************************************/
#pragma once
-#include <float.h>
-#include <stdio.h>
+#include <cfloat>
+#include <cstdio>
+#include <cmath>
////////////////////////////////////////////////////////////////////////////////

View File

@ -43,11 +43,7 @@
#pragma once
-#if defined(__CUDACC_RTC__)
-#include <cuda/std/cassert>
-#else
-#include <assert.h>
-#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
@ -57,12 +53,9 @@
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"

View File

@ -43,11 +43,7 @@
#pragma once
-#if defined(__CUDACC_RTC__)
-#include <cuda/std/cassert>
-#else
-#include <assert.h>
-#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
@ -57,16 +53,12 @@
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/scale_type.h"

View File

@ -550,7 +550,7 @@ public:
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
-typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
+typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
{problem_size_1_k, problem_size_1_n},
thread_id(),
@ -719,7 +719,7 @@ public:
}
typename MM1::Mma::IteratorB iterator_V(
-typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
+typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
{problem_size_1_k, problem_size_1_n},
thread_id(),
@ -761,15 +761,15 @@ public:
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
-kIsLast,
+kIsLast::value,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
output_accum_t,
-kIsFirst,
-kIsLast,
+kIsFirst::value,
+kIsLast::value,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
@ -777,7 +777,7 @@ public:
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
-kIsLast,
+kIsLast::value,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
@ -795,7 +795,7 @@ public:
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter = gemm_kernel_utils::call_conditional<
-kIsLast,
+kIsLast::value,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
@ -817,8 +817,8 @@ public:
}
if (kKeepOutputInRF) {
-const bool kIsFirst = true;
-const bool kIsLast = true;
+constexpr bool kIsFirst = true;
+constexpr bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;

View File

@ -55,13 +55,14 @@
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
{ \
if (BOOL_V) { \
-constexpr bool BOOL_NAME = true; \
+using BOOL_NAME = std::true_type; \
F(); \
} else { \
-constexpr bool BOOL_NAME = false; \
+using BOOL_NAME = std::false_type; \
F(); \
} \
}
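The switch from a `constexpr bool` to `std::true_type`/`std::false_type` is why call sites elsewhere in this commit now read `kForceReloadK::value` and `kIsFirst::value`: a type tag carries the flag into nested lambdas as a genuine compile-time constant, whereas a `constexpr` local named inside a lambda is not portably usable as one. A minimal sketch of the pattern, independent of CUTLASS (all names here are illustrative, not from the diff):

```cpp
#include <cstdio>
#include <type_traits>

// run_kernel stands in for any code that needs the flag as a template argument.
template <bool kFlag>
void run_kernel() { std::printf("kFlag = %d\n", kFlag); }

// Turn a runtime bool into a type tag, mirroring what DISPATCH_BOOL does.
template <class F>
void dispatch_bool(bool flag, F&& f) {
  if (flag) { f(std::true_type{}); }
  else      { f(std::false_type{}); }
}

int main() {
  bool runtime_flag = true;  // value known only at runtime
  dispatch_bool(runtime_flag, [](auto kFlag) {
    // decltype(kFlag)::value is a constant expression in both branches.
    run_kernel<decltype(kFlag)::value>();
  });
  return 0;
}
```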
#define DISPATCH_ARCHTAG(CC, func) \
{ \
if (CC >= 80) { \

View File

@ -32,6 +32,7 @@
#pragma once
#include <cmath>
#include <cinttypes>
+#include <type_traits>
#include <vector>
@ -85,8 +86,6 @@
#include "gemm/mma_from_smem.h"
#include "transform/tile_smem_loader.h"
-#include <inttypes.h>
using namespace gemm_kernel_utils;
namespace {
@ -1956,7 +1955,8 @@ struct AttentionBackwardKernel {
// no-op epilogue operator - just casting and storing contents of
// accum to global memory
-typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1});
+typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op(
+typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1});
typename MatmulDOIVJ::BiasGradEpilogue epilogue(
shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id);
epilogue(output_op, output_iter, accum, output_iter);
@ -2211,7 +2211,7 @@ struct AttentionBackwardKernel {
incrIteration(p, query_start, key_start, next_query, next_key);
DISPATCH_BOOL(
next_key != key_start, kForceReloadK, ([&]() {
-prologueQkNextIteration<kForceReloadK>(
+prologueQkNextIteration<kForceReloadK::value>(
shared_storage, p, next_query, next_key, warp_id, lane_id);
}));
}
@ -2342,7 +2342,7 @@ struct AttentionBackwardKernel {
thread_id,
cutlass::MatrixCoord{0, 0});
-MatmulQK::Mma::prologue<kReloadK, true>(
+MatmulQK::Mma::template prologue<kReloadK, true>(
shared_storage.mm_qk_k(),
shared_storage.mm_qk_q(),
iterator_A,
@ -2369,6 +2369,7 @@ struct AttentionBackwardKernel {
p.grad_value_ptr + key_start * p.gV_strideM(),
{num_keys_in_block, p.head_dim_value},
thread_id);
accumulateInGmem<MatmulGradV>(
shared_storage.gradV_epilogue_final(),
output_frags.gradV,
@ -2406,7 +2407,7 @@ struct AttentionBackwardKernel {
int thread_id = 32 * warp_id + lane_id;
DISPATCH_BOOL(
first, kIsFirst, ([&]() {
-static constexpr auto ScaleType = kIsFirst
+static constexpr auto ScaleType = kIsFirst::value
? cutlass::epilogue::thread::ScaleType::Nothing
: cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using EpilogueOutputOp =

View File

@ -38,6 +38,7 @@
#include <curand_kernel.h>
#include <cmath>
+#include <cinttypes>
#include <vector>
#include "cutlass/fast_math.h"
@ -71,8 +72,6 @@
#include "gemm_kernel_utils.h"
#include "transform/tile_smem_loader.h"
-#include <inttypes.h>
using namespace gemm_kernel_utils;
namespace {
@ -1036,15 +1035,15 @@ struct AttentionKernel {
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
-kIsLast,
+kIsLast::value,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
ElementCompute,
-kIsFirst,
-kIsLast,
+kIsFirst::value,
+kIsLast::value,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
@ -1052,7 +1051,7 @@ struct AttentionKernel {
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
-kIsLast,
+kIsLast::value,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
@ -1070,7 +1069,7 @@ struct AttentionKernel {
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter = call_conditional<
-kIsLast,
+kIsLast::value,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);

View File

@ -39,11 +39,7 @@
#pragma once
-#if defined(__CUDACC_RTC__)
-#include <cuda/std/cassert>
-#else
-#include <assert.h>
-#endif
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@ -53,12 +49,9 @@
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"

View File

@ -43,7 +43,7 @@ class gen_test:
def gen_cpp_sample(self):
code = "/* Auto Generated code - Do not edit.*/\n"
code += "#include <stdio.h> \n"
code += "#include <cstdio> \n"
code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n"
code += "#include \"cutlass/cutlass.h\" \n"

View File

@ -380,7 +380,7 @@ class gen_one_API:
def gen_CUTLASS_irrelevant_API(self):
code = ""
code += "#include <cuda_runtime.h>\n"
code += "#include <assert.h>\n"
code += "#include <cassert>\n"
param_name = "Fused" + str(self.b2b_num) + "xGemm_"
for i in range(self.b2b_num):

View File

@ -66,7 +66,7 @@ int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string &
return -1;
}
-if (!(props.major == arch_major && props.minor == arch_minor)) {
+if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) {
supported = false;
}

View File

@ -38,11 +38,7 @@
#pragma once
-#if defined(__CUDACC_RTC__)
-#include <cuda/std/cassert>
-#else
-#include <assert.h>
-#endif
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

View File

@ -45,18 +45,18 @@
and BEFORE scatter operations are applied.
*/
-#include <stdlib.h>
-#include <stdio.h>
-#include <time.h>
-#include <math.h>
-#include <assert.h>
-#include <cuda_runtime.h>
+#include <cstdlib>
+#include <cstdio>
+#include <ctime>
+#include <cmath>
+#include <cassert>
#include <algorithm>
#include <iostream>
#include <random>
#include <numeric>
+#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
@ -64,7 +64,6 @@
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"

View File

@ -619,7 +619,6 @@ int main(int argc, char const **args) {
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
@ -681,4 +680,4 @@ int main(int argc, char const **args) {
return 0;
}
-/////////////////////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
+/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -559,4 +559,4 @@ int main(int argc, char const **args) {
return 0;
}
-/////////////////////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
+/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -9,12 +9,17 @@ This first version only supports mixed type GEMMs using TMA.
## Performance
-While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.
+While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type.
The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array<ElementScale, 8>` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now.
Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details.
We are currently optimizing the following cases:
1. Memory bound cases for all types
2. `fp8 x {int2, uint2}` case
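The reordering flow mentioned above can be sketched roughly as follows. This is a hedged illustration only: `compute_memory_reordering_atom` and the in-place `reorder_tensor` overload come from this example's utility headers, and the element types, `shape_B`, and `layout_B` are placeholders; exact signatures may differ.

```cpp
#include "cute/tensor.hpp"

// Hedged sketch: rearrange the narrow-type tensor B so that the values one
// thread reads are contiguous in global/shared memory.
template <class ElementMma, class ElementQuant, class ShapeB, class LayoutB>
void reorder_for_mixed_dtype(ElementQuant* ptr_B, ShapeB shape_B, LayoutB layout_B) {
  // Build the layout atom for the wide (MMA) type, then tile it over B's shape.
  auto atom_layout        = compute_memory_reordering_atom<ElementMma>();
  auto layout_B_reordered = cute::tile_to_shape(atom_layout, shape_B);
  // Rearrange the quantized tensor in place from its canonical layout.
  reorder_tensor(ptr_B, layout_B, layout_B_reordered);
}
```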
## Limitations
@ -36,4 +41,4 @@ We are currently optimizing the following cases:
* Optimizations for memory bound cases.
-* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
\ No newline at end of file
+* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.

View File

@ -151,16 +151,16 @@ void mixed_dtype_profiling(
runtimes.reserve(options.iterations);
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
cudaEventRecord(start);
CUTLASS_CHECK(gemm.run());
cudaEventRecord(stop);
cudaEventSynchronize(stop);
if (iter >= options.warmup) {
float milliseconds = 0;
cudaEventElapsedTime(&milliseconds, start, stop);
runtimes.push_back(milliseconds);
}
}
cudaEventDestroy(start);
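For reference, a hedged sketch of how the post-warmup timings collected in `runtimes` might be summarized after this loop; this helper is not part of the diff and its name is invented:

```cpp
#include <numeric>
#include <vector>

// Average the post-warmup timings gathered by the profiling loop above.
inline double average_ms(std::vector<float> const& runtimes) {
  if (runtimes.empty()) { return 0.0; }
  return std::accumulate(runtimes.begin(), runtimes.end(), 0.0) /
         static_cast<double>(runtimes.size());
}
```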

View File

@ -33,6 +33,9 @@
#include <cstdint>
#include "cutlass/util/device_memory.h"
#include "cutlass/integer_subbyte.h"
#include "cutlass/float8.h"
#include "cutlass/util/reference/device/tensor_fill.h"
@ -197,7 +200,6 @@ bool initialize_packed_scale(
{
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
-// std::cout << data_in[i] << ":" << std::hex << static_cast<uint16_t>(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast<uint16_t>((-data_in[i]).storage) << std::endl;
}
try {
block_out.copy_from_host(data_out.data());
@ -207,4 +209,4 @@ bool initialize_packed_scale(
return false;
}
return true;
-}
\ No newline at end of file
+}

View File

@ -159,4 +159,4 @@ void reorder_tensor(
cutlass::DeviceAllocation<T> temp(size(layout_src));
reorder_tensor(data, layout_src, temp.get(), layout_dst);
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
-}
\ No newline at end of file
+}

View File

@ -63,7 +63,7 @@
#include <fstream>
#include <sstream>
#include <vector>
-#include <float.h>
+#include <cfloat>
#include "cutlass/cutlass.h"

View File

@ -35,9 +35,36 @@
#include "dispatch_policy_extra.hpp"
#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
#include "../pipeline/prefetch_pipeline_sm90.hpp"
namespace cutlass::gemm::collective {
namespace detail {
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int stages>
constexpr int
compute_stage_count_or_override_prefetch(StageCount<stages> stage_count) {
return stages;
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes>
constexpr int
compute_stage_count_or_override_prefetch(StageCountAutoCarveout<carveout_bytes> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size
constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}));
constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast<int>(mainloop_pipeline_bytes);
return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes;
}
} // namespace detail
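To make the carveout arithmetic in `compute_stage_count_or_override_prefetch` concrete, here is a hedged worked example. All sizes below are assumptions for illustration: the real `SharedStorage` byte counts and the SM90 smem capacity constant are implementation details and may differ.

```cpp
// Illustrative numbers only -- not taken from the library.
constexpr int CapacityBytes           = 232448; // ~227 KB usable smem on SM90 (assumed)
constexpr int carveout_bytes          = 0;      // no epilogue carveout assumed
constexpr int M = 128, N = 128, K = 64;         // TileShapeMNK, fp8 x fp8 (assumed)
constexpr int MK_bytes                = (8 * M * K) / 8; // 8192; also the prefetch tile size
constexpr int NK_bytes                = (8 * N * K) / 8; // 8192
constexpr int mainloop_pipeline_bytes = 16;     // assumed barrier storage per stage
constexpr int prefetch_pipeline_bytes = 64;     // assumed prefetcher barrier storage
constexpr int PrefetchStagesActual    = 1;      // single, reused prefetch buffer
constexpr int stage_bytes = MK_bytes + NK_bytes + mainloop_pipeline_bytes; // 16400

// (232448 - 0 - 8192 - 64) / 16400 = 13 mainloop stages
constexpr int stages =
    (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual -
     prefetch_pipeline_bytes) / stage_bytes;
static_assert(stages == 13, "illustrative value under the assumptions above");
```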
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
template <
class ElementA,
@ -98,7 +125,7 @@ struct CollectiveBuilder<
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
-static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
+static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
@ -184,7 +211,7 @@ struct CollectiveBuilder<
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
-static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
+static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;

View File

@ -57,6 +57,19 @@ using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
+constexpr int PrefetchStages = 4;
+constexpr int PrefetchInitialStages = 1;
+// This determines how much shmem we set aside for prefetch.
+// We don't reuse anything loaded by prefetcher, so we can keep
+// loading into the same place -- there will be a conflict when
+// writing, but it doesn't affect performance as much as the doors
+// that this opens.
+constexpr int PrefetchStagesActual = 1;
} // namespace detail
// WarpSpecialized Mainloop
template <
int Stages,
@ -117,15 +130,7 @@ struct CollectiveMma<
static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1");
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
-static constexpr int PrefetchStages = 4;
-static constexpr int PrefetchInitialStages = 1;
-// This determines how much shmem we set aside for prefetch.
-// We don't reuse anything loaded by prefetcher, so we can keep
-// loading into the same place -- there will be a conflict when
-// writing, but it doesn't affect performance as much as the doors
-// that this opens.
-static constexpr int PrefetchStagesActual = 1;
-using PrefetcherPipeline = cutlass::PrefetchPipeline<PrefetchStages>;
+using PrefetcherPipeline = cutlass::PrefetchPipeline<detail::PrefetchStages>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
@ -155,7 +160,7 @@ struct CollectiveMma<
using PrefetchSmemLayoutA = decltype(make_layout(make_shape(
cute::Int<size<0>(SmemLayoutA{})>{},
cute::Int<size<1>(SmemLayoutA{})>{},
-cute::Int<PrefetchStagesActual>{})));
+cute::Int<detail::PrefetchStagesActual>{})));
static constexpr auto prefetch_smem_size = cute::cosize_v<PrefetchSmemLayoutA>;
@ -176,7 +181,7 @@ struct CollectiveMma<
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
// Defined outside the class where it's used, to work around MSVC issues
-using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>;
+using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<detail::PrefetchStages>;
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128, _0> {
@ -660,7 +665,7 @@ struct CollectiveMma<
bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0;
float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio;
int prefetch_iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5 * prefetch_ratio);
-prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages);
+prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages);
Tensor sA = make_tensor(
make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
@ -702,7 +707,7 @@ struct CollectiveMma<
break;
}
-prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages);
+prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages);
using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType;
BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage);
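The `prefetch_iters` computation above halves the K-tile count, scales it by `prefetch_ratio`, and rounds up to a whole number of prefetch stages before clamping to `k_tile_count`. A standalone sketch of that arithmetic, assuming `detail::PrefetchStages == 4` as defined earlier in this file:

```cpp
#include <algorithm>

// Mirror of the prefetch-iteration rounding above, with PrefetchStages fixed at 4.
int prefetch_iter_count(int k_tile_count, float prefetch_ratio) {
  constexpr int PrefetchStages = 4;
  int iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5f * prefetch_ratio);
  // Round up to a multiple of PrefetchStages, then never exceed k_tile_count.
  iters = ((iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages;
  return std::min(k_tile_count, iters);
}
// e.g. k_tile_count = 10, ratio = 1.0: 5 -> rounds up to 8 -> min(10, 8) = 8
```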

View File

@ -0,0 +1,35 @@
# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
64_ada_fp8_gemm_grouped
ada_fp8_gemm_grouped.cu
)

File diff suppressed because it is too large.

View File

@ -143,8 +143,10 @@ foreach(EXAMPLE
61_hopper_gemm_with_topk_and_softmax
62_hopper_sparse_gemm
63_hopper_gemm_with_weight_prefetch
+64_ada_fp8_gemm_grouped
)
add_subdirectory(${EXAMPLE})
endforeach()

View File

@ -95,36 +95,17 @@ __global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
-template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
-__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
+template <class TensorS, class TensorD, class Tiled_Copy>
+__global__ void copy_kernel_vectorized(TensorS S, TensorD D, Tiled_Copy tiled_copy)
{
using namespace cute;
using Element = typename TensorS::value_type;
// Slice the tensors to obtain a view into each tile.
Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
-// Define `AccessType` which controls the size of the actual memory access.
-using AccessType = cutlass::AlignedArray<Element, size(VecLayout{})>;
-// A copy atom corresponds to one hardware memory access.
-using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
-// Construct tiled copy, a tiling of copy atoms.
-//
-// Note, this assumes the vector and thread layouts are aligned with contigous data
-// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
-// reads. Alternative vector layouts are also possible, though incompatible layouts
-// will result in compile time errors.
-auto tiled_copy =
-make_tiled_copy(
-Atom{}, // access size
-ThreadLayout{}, // thread layout
-VecLayout{}); // vector layout (e.g. 4x1)
// Construct a Tensor corresponding to each thread's slice.
-auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
+ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CopyOp, CopyM, CopyN)
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CopyOp, CopyM, CopyN)
@ -198,11 +179,34 @@ int main(int argc, char** argv)
Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); // ((M, N), m', n')
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
-// Thread arrangement
-Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{}));
+// Construct a TiledCopy with a specific access pattern.
+// This version uses a
+// (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc),
+// (2) Layout-of-Values that each thread will access.
-// Vector dimensions
-Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));
+// Thread arrangement
+Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{})); // (32,8) -> thr_idx
+// Value arrangement per thread
+Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx
+// Define `AccessType` which controls the size of the actual memory access instruction.
+using CopyOp = UniversalCopy<uint_byte_t<sizeof(Element) * size(val_layout)>>; // A very specific access width copy instruction
+//using CopyOp = UniversalCopy<cutlass::AlignedArray<Element, size(val_layout)>>; // A more generic type that supports many copy strategies
+//using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs
+// A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element.
+using Atom = Copy_Atom<CopyOp, Element>;
+// Construct tiled copy, a tiling of copy atoms.
+//
+// Note, this assumes the vector and thread layouts are aligned with contiguous data
+// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
+// reads. Alternative value layouts are also possible, though incompatible layouts
+// will result in compile time errors.
+TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy
+thr_layout, // thread layout (e.g. 32x4 Col-Major)
+val_layout); // value layout (e.g. 4x1)
//
// Determine grid and block dimensions
@ -217,8 +221,7 @@ int main(int argc, char** argv)
copy_kernel_vectorized<<< gridDim, blockDim >>>(
tiled_tensor_S,
tiled_tensor_D,
-thr_layout,
-vec_layout);
+tiled_copy);
cudaError result = cudaDeviceSynchronize();
if (result != cudaSuccess) {