CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@ -9,7 +9,8 @@ Clang as both host and device compiler ("CUDA Clang").
|
||||
|
||||
# Software prerequisites
|
||||
|
||||
1. Clang (tested with Clang 14)
|
||||
1. Clang (regularly tested with Clang 14;
|
||||
occasionally tested with Clang 10 and greater)
|
||||
|
||||
2. CUDA Toolkit (tested with 12.2; other versions likely work)
|
||||
|
||||
@ -32,14 +33,18 @@ is the following error when attempting to use clang:
|
||||
|
||||
# Running CMake
|
||||
|
||||
The Clang build requires specifying the following three CMake options.
|
||||
## Required CMake options
|
||||
|
||||
* `CMAKE_CXX_COMPILER=clang++`
|
||||
* `CMAKE_CUDA_HOST_COMPILER=clang++`
|
||||
The Clang build requires specifying the following CMake options.
|
||||
Replace `<path-to-clang++>` with the path to your `clang++` executable,
|
||||
and replace `<path-to-clang>` with the path to your `clang` executable
|
||||
(which must have the same version as your `clang++` executable).
|
||||
You may use `clang++` resp. `clang` directly if they are in your `PATH`.
|
||||
|
||||
* `CMAKE_C_COMPILER=clang`
|
||||
* `CMAKE_CXX_COMPILER=<path-to-clang++>`
|
||||
* `CMAKE_CUDA_HOST_COMPILER=<path-to-clang++>`
|
||||
* `CMAKE_C_COMPILER=<path-to-clang>`
|
||||
|
||||
This assumes that `clang++` and `clang` are in the user's `PATH`.
|
||||
Please note that both `CMAKE_CXX_COMPILER` and `CMAKE_C_COMPILER`
|
||||
must be set, even though CUTLASS is a C++ project, not a C project.
|
||||
|
||||
@ -51,3 +56,4 @@ if `${PATH_TO_CUDA_TOOLKIT}` is the CUDA Toolkit directory,
|
||||
then one can set `CMAKE_CUDA_COMPILER` as follows.
|
||||
|
||||
* `CMAKE_CUDA_COMPILER=${PATH_TO_CUDA_TOOLKIT}/bin/nvcc`
|
||||
|
||||
|
||||
@ -1,83 +1,233 @@
|
||||
# TMA tensors
|
||||
# CuTe TMA Tensors
|
||||
|
||||
TMA tensors have three differences from
|
||||
"ordinary" global memory tensors.
|
||||
Along your travels, you may find strange looking CuTe Tensors that are printed as something like
|
||||
```
|
||||
ArithTuple(0,_0,_0,_0) o ((_128,_64),2,3,1):((_1@0,_1@1),_64@1,_1@2,_1@3)
|
||||
```
|
||||
What is an `ArithTuple`? Are those tensor strides? What do those mean? What is this for?
|
||||
|
||||
1. The tensor's iterator stores a base coordinate,
|
||||
not a pointer.
|
||||
This documentation intends to answer those questions and introduce some of the more advanced features of CuTe.
|
||||
|
||||
2. The tensor's actual global memory pointer
|
||||
does not live in the tensor.
|
||||
Instead, it lives in a TMA descriptor,
|
||||
which is stored in the TMA `Copy_Traits` specialization.
|
||||
# Introduction to TMA instructions
|
||||
|
||||
3. The tensor's strides aren't just integers.
|
||||
Instead, they are linear combinations of "basis functions."
|
||||
The Tensor Memory Accelerator (TMA) is a set of instructions for copying possibly multidimensional arrays between global and shared memory. TMA was introduced in the Hopper architecture. A single TMA instruction can copy an entire tile of data all at once. As a result, the hardware no longer needs to compute individual memory addresses and issue a separate copy instruction for each element of the tile.
|
||||
|
||||
The following sections will elaborate these differences.
|
||||
To accomplish this, the TMA instruction is given a *TMA descriptor*, which is a packed representation of a multidimensional tensor in global memory with 1, 2, 3, 4, or 5 dimensions. The TMA descriptor holds
|
||||
|
||||
## Iterator stores a base coordinate, not a pointer
|
||||
* the base pointer of the tensor;
|
||||
|
||||
"Ordinary" tensors of global memory have an iterator type
|
||||
(the "Engine" template parameter) that wraps a pointer.
|
||||
For example, `gmem_ptr<T>` wraps a `T*`.
|
||||
A TMA tensor's iterator type is `ArithmeticTupleIterator`.
|
||||
`ArithmeticTupleIterator` stores a coordinate
|
||||
(a tuple of integers) instead of a pointer.
|
||||
The coordinate is represented as an `ArithmeticTuple`,
|
||||
which is just a (public subclass of) `cute::tuple`
|
||||
that has an overloaded `operator+`.
|
||||
The sum of two tuples is the tuple of the sum of the elements.
|
||||
* the data type of the tensor's elements (e.g., `int`, `float`, `double`, or `half`);
|
||||
|
||||
When we perform the TMA load or store,
|
||||
the iterator's coordinate goes into the PTX instruction.
|
||||
(For TMA specializations of `Copy_Traits`,
|
||||
this happens in the `private` member function `copy_unpack_`.)
|
||||
The coordinate represents the tensor's "base coordinate."
|
||||
For tiled TMA, the base coordinate of the whole tensor
|
||||
might start out as (0, 0, ..., 0). However, slicing the tensor
|
||||
might result in a different base coordinate.
|
||||
For im2col TMA load, the base coordinate is the lower corner.
|
||||
* the size of each dimension;
|
||||
|
||||
## Pointer lives in TMA descriptor, not tensor
|
||||
* the stride within each dimension; and
|
||||
|
||||
The TMA descriptor has the actual pointer to global memory in it.
|
||||
Storing the TMA descriptor in the tensor would make tensors
|
||||
expensive to copy and slice, as the TMA descriptor is 128 bytes.
|
||||
Instead, we store the TMA descriptor
|
||||
in the `Copy_Traits` specialization.
|
||||
* other flags representing the smem box size, smem swizzling patterns, and out-of-bounds access behavior.
|
||||
|
||||
## Tensor's strides aren't just integers
|
||||
This descriptor must be created on the host before kernel execution.
|
||||
It is shared between all thread blocks that will be issuing TMA instructions.
|
||||
Once inside the kernel, the TMA is executed with the following parameters:
|
||||
|
||||
For "ordinary" tensors, the layout takes a coordinate
|
||||
`(i, j)` as input, and returns a single integer offset `k`.
|
||||
The resulting pointer-to-element
|
||||
is the base pointer, plus the offset k.
|
||||
However, TMA loads and stores don't take a pointer.
|
||||
They take a TMA descriptor, and a coordinate `(i, j)`.
|
||||
Building the strides out of "basis functions"
|
||||
is the trick to make the layout return a coordinate --
|
||||
a tuple of integers -- instead of just a single integer offset.
|
||||
A "basis function" for strides
|
||||
is a lot like a basis function for Euclidean space,
|
||||
except that strides' basis functions can be hierarchical.
|
||||
* pointer to the TMA descriptor;
|
||||
|
||||
* pointer to the SMEM; and
|
||||
|
||||
* coordinates into the GMEM tensor represented within the TMA descriptor.
|
||||
|
||||
For example, the interface for TMA-store with 3-D coordinates looks like this.
|
||||
|
||||
```cpp
|
||||
struct SM90_TMA_STORE_3D {
|
||||
CUTE_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) {
|
||||
// ... invoke CUDA PTX instruction ...
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
We observe that the TMA instruction does not directly consume pointers to global memory. Indeed, the global memory pointer is contained in the descriptor, is considered constant, and is NOT a separate parameter to the TMA instruction. Instead, the TMA consumes TMA coordinates into the TMA's view of global memory that is defined in the TMA descriptor.
|
||||
|
||||
That means that an ordinary CuTe Tensor that stores a GMEM pointer and computes offsets and new GMEM pointers is useless to the TMA.
|
||||
|
||||
What do we do?
|
||||
|
||||
# Building a TMA Tensor
|
||||
|
||||
## Implicit CuTe Tensors
|
||||
|
||||
All CuTe Tensors are compositions of Layouts and Iterators. An ordinary global memory tensor's iterator is its global memory pointer. However, a CuTe Tensor's iterator doesn't have to be a pointer; it can be any random-access iterator.
|
||||
|
||||
One example of such an iterator is a *counting iterator*.
|
||||
This represents a possibly infinite sequence of integers that starts at some value.
|
||||
We call the members of this sequence *implicit integers*,
|
||||
because the sequence is not explicitly stored in memory.
|
||||
The iterator just stores its current value.
|
||||
|
||||
We can use a counting iterator to create a tensor of implicit integers,
|
||||
```cpp
|
||||
Tensor A = make_tensor(counting_iterator<int>(42), make_shape(4,5));
|
||||
print_tensor(A);
|
||||
```
|
||||
which outputs
|
||||
```
|
||||
counting_iter(42) o (4,5):(_1,4):
|
||||
42 46 50 54 58
|
||||
43 47 51 55 59
|
||||
44 48 52 56 60
|
||||
45 49 53 57 61
|
||||
```
|
||||
This tensor maps logical coordinates to on-the-fly computed integers. Because it's still a CuTe Tensor, it can still be tiled and partitioned and sliced just like a normal tensor by accumulating integer offsets into the iterator.
|
||||
|
||||
But the TMA doesn't consume pointers or integers, it consumes coordinates. Can we make a tensor of implicit TMA
|
||||
coordinates for the TMA instruction to consume? If so, then we could presumably also tile and partition and slice that tensor of coordinates so that we would always have the right TMA coordinate to give to the instruction.
|
||||
|
||||
## ArithTupleIterators and ArithTuples
|
||||
|
||||
First, we build a `counting_iterator` equivalent for TMA coordinates. It should support
|
||||
|
||||
* dereference to a TMA coordinate, and
|
||||
|
||||
* offset by another TMA coordinate.
|
||||
|
||||
We'll call this an `ArithmeticTupleIterator`. It stores a coordinate (a tuple of integers) that is represented as an `ArithmeticTuple`. The `ArithmeticTuple` is simply a (public subclass of) `cute::tuple` that has an overloaded `operator+` so that it can be offset by another tuple. The sum of two tuples is the tuple of the sum of the elements.
|
||||
|
||||
Now similar to `counting_iterator<int>(42)` we can create an implicit "iterator" (but without increment or other common iterator operations) over tuples that can be dereferenced and offset by other tuples
|
||||
```cpp
|
||||
ArithmeticTupleIterator citer_1 = make_inttuple_iter(42, Int<2>{}, Int<7>{});
|
||||
ArithmeticTupleIterator citer_2 = citer_1 + make_tuple(Int<0>{}, 5, Int<2>{});
|
||||
print(*citer_2);
|
||||
```
|
||||
which outputs
|
||||
```
|
||||
(42,7,_9)
|
||||
```
|
||||
|
||||
A TMA Tensor can use an iterator like this to store the current TMA coordinate "offset". The "offset" here is in quotes because it's clearly not a normal 1-D array offset or pointer.
|
||||
|
||||
In summary, one creates a TMA descriptor for the *whole global memory tensor*. The TMA descriptor defines a view into that tensor and the instruction takes TMA coordinates into that view. In order to generate and track those TMA coordinates, we define an implicit CuTe Tensor of TMA coordinates that can be tiled, sliced, and partitioned the exact same way as an ordinary CuTe Tensor.
|
||||
|
||||
We can now track and offset TMA coordinates with this iterator, but how do we get CuTe Layouts to generate non-integer offsets?
|
||||
|
||||
## Strides aren't just integers
|
||||
|
||||
Ordinary tensors have a layout that maps
|
||||
a logical coordinate `(i,j)` into a 1-D linear index `k`.
|
||||
This mapping is the inner-product of the coordinate with the strides.
|
||||
|
||||
TMA Tensors hold iterators of TMA coordinates.
|
||||
Thus, a TMA Tensor's Layout must map a logical coordinate
|
||||
to a TMA coordinate, rather than to a 1-D linear index.
|
||||
|
||||
To do this, we can abstract what a stride is. Strides need not be integers, but rather any algebraic object that supports inner-product with the integers (the logical coordinate). The obvious choice is the `ArithmeticTuple` we used earlier since they can be added to each other, but this time additionally equipped with an `operator*` so it can also be scaled by an integer.
|
||||
|
||||
### Aside: Integer-module strides
|
||||
|
||||
A group of objects that support addition between elements and product between elements and integers is called an integer-module.
|
||||
|
||||
Formally, an integer-module is an abelian group `(M,+)` equipped with `Z*M -> M`, where `Z` are the integers. That is, an integer-module `M` is
|
||||
a group that supports inner products with the integers.
|
||||
The integers are an integer-module.
|
||||
Rank-R tuples of integers are an integer-module.
|
||||
|
||||
In principle, layout strides may be any integer-module.
|
||||
|
||||
### Basis elements
|
||||
|
||||
CuTe's basis elements live in the header file `cute/numeric/arithmetic_tuple.hpp`.
|
||||
To make it easy to create `ArithmeticTuple`s that can be used as strides, CuTe defines normalized basis elements using the `E` type alias. "Normalized" means that the scaling factor of the basis element is the compile-time integer 1.
|
||||
|
||||
| C++ object | Description | String representation |
|
||||
| --- | --- | --- |
|
||||
| `E<>{}` | `1` | `1` |
|
||||
| `E<0>{}` | `(1,0,...)` | `1@0` |
|
||||
| `E<1>{}` | `(0,1,0,...)` | `1@1` |
|
||||
| `E<0,1>{}` | `((0,1,0,...),0,...)` | `1@1@0` |
|
||||
| `E<1,0>{}` | `(0,(1,0,...),0,...)` | `1@0@1` |
|
||||
|
||||
The "description" column in the above table
|
||||
interprets each basis element as an infinite tuple of integers,
|
||||
where all the tuple's entries not specified by the element's type are zero.
|
||||
We count tuple entries from left to right, starting with zero.
|
||||
For example, `E<1>{}` has a 1 in position 1: `(0,1,0,...)`.
|
||||
`E<3>{}` has a 1 in position 3: `(0,0,0,1,0,...)`.
|
||||
|
||||
Basis elements can be *nested*.
|
||||
For instance, in the above table, `E<0,1>{}` means that
|
||||
in position 0 there is a `E<1>{}`: `((0,1,0,...),0,...)`.
|
||||
|
||||
Basis elements can be *scaled*.
|
||||
That is, they can be multiplied by an integer *scaling factor*.
|
||||
For example, in `5*E<1>{}`, the scaling factor is `5`.
|
||||
`5*E<1>{}` prints as `5@1` and means `(0,5,0,...)`.
|
||||
The scaling factor commutes through any nesting.
|
||||
For instance, `5*E<0,1>{}` prints as `5@1@0`
|
||||
and means `((0,5,0,...),0,...)`.
|
||||
|
||||
Basis elements can also be added together,
|
||||
as long as their hierarchical structures are compatible.
|
||||
For example, `3*E<0>{} + 4*E<1>{}` results in `(3,4,0,...)`.
|
||||
Intuitively, "compatible" means that
|
||||
the nested structure of the two basis elements
|
||||
matches well enough to add the two elements together.
|
||||
|
||||
### Linear combinations of strides
|
||||
|
||||
Layouts work by taking the inner product
|
||||
of their input coordinate with the strides.
|
||||
For "ordinary" integer strides, e.g., `(1, 100)`,
|
||||
the inner product of the input coordinate `(i, j)`
|
||||
and the strides is `i + 100j`.
|
||||
That gives the formula for the offset.
|
||||
For strides built of basis functions, for example,
|
||||
if the strides are `(_1@0, _1@1)`,
|
||||
then the inner product of the input coordinate `(i, j)`
|
||||
with the strides is `i@0 + j@1`.
|
||||
The `i` here is a coefficient of the basis function `@0`,
|
||||
and `j` is a coefficient of the basis function `@1`.
|
||||
The result is a vector sum. We _interpret_ this result as
|
||||
"the zeroth coefficient is i, and the first coefficient is j."
|
||||
That translates into the (TMA) coordinate `(i, j)`.
|
||||
of the natural coordinate with their strides.
|
||||
For strides made of integer elements, e.g., `(1,100)`,
|
||||
the inner product of the input coordinate `(i,j)`
|
||||
and the stride is `i + 100j`.
|
||||
Offsetting an "ordinary" tensor's pointer and this index
|
||||
gives the pointer to the tensor element at `(i,j)`.
|
||||
|
||||
For strides of basis elements, we still compute the inner product of the natural coordinate with the strides.
|
||||
For example, if the stride is `(1@0,1@1)`,
|
||||
then the inner product of the input coordinate `(i,j)`
|
||||
with the strides is `i@0 + j@1 = (i,j)`.
|
||||
That translates into the (TMA) coordinate `(i,j)`.
|
||||
If we wanted to reverse the coordinates,
|
||||
then we could use `(_1@1, _1@0)` as the strides.
|
||||
Evaluating the layout would give `i@1 + j@0`,
|
||||
that is, `(j, i)`.
|
||||
then we could use `(1@1,1@0)` as the stride.
|
||||
Evaluating the layout would give `i@1 + j@0 = (j,i)`.
|
||||
|
||||
A linear combination of basis elements
|
||||
can be interpreted as a possibly multidimensional and hierarchical coordinate.
|
||||
For instance, `2*2@1@0 + 3*1@1 + 4*5@1 + 7*1@0@0`
|
||||
means `((0,2,...),0,...) + (0,3,0,...) + (0,20,0,...) + ((7,...),...) = ((7,2,...),23,...)`
|
||||
and can be interpreted as the coordinate `((7,2),23)`.
|
||||
|
||||
Thus, linear combinations of these strides can be used to generate TMA coordinates.
|
||||
These coordinates, in turn, can be used to offset TMA coordinate iterators.
|
||||
|
||||
## Application to TMA Tensors
|
||||
|
||||
Now we can build CuTe Tensors like the one seen in the introduction.
|
||||
|
||||
```cpp
|
||||
Tensor a = make_tensor(make_inttuple_iter(0,0),
|
||||
make_shape ( 4, 5),
|
||||
make_stride(E<0>{}, E<1>{}));
|
||||
print_tensor(a);
|
||||
|
||||
Tensor b = make_tensor(make_inttuple_iter(0,0),
|
||||
make_shape ( 4, 5),
|
||||
make_stride(E<1>{}, E<0>{}));
|
||||
print_tensor(b);
|
||||
```
|
||||
prints
|
||||
```
|
||||
ArithTuple(0,0) o (4,5):(_1@0,_1@1):
|
||||
(0,0) (0,1) (0,2) (0,3) (0,4)
|
||||
(1,0) (1,1) (1,2) (1,3) (1,4)
|
||||
(2,0) (2,1) (2,2) (2,3) (2,4)
|
||||
(3,0) (3,1) (3,2) (3,3) (3,4)
|
||||
|
||||
ArithTuple(0,0) o (4,5):(_1@1,_1@0):
|
||||
(0,0) (1,0) (2,0) (3,0) (4,0)
|
||||
(0,1) (1,1) (2,1) (3,1) (4,1)
|
||||
(0,2) (1,2) (2,2) (3,2) (4,2)
|
||||
(0,3) (1,3) (2,3) (3,3) (4,3)
|
||||
```
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
Although CUTLASS 3.0 restructures the GEMM hierarchy and introduces new types for the
|
||||
threadblock layer and below, we intend the entire source code to be usable in user applications.
|
||||
We expect users to be able to `#include` any source file from CUTLASS 3.0, whether
|
||||
they implement the 2.x or the 3.x API, without breaking user builds. This means that a single
|
||||
they implement the 2.x or the 3.x API, without breaking user builds. This means that a single
|
||||
translation unit should be able to contain any valid kernel regardless of its API version. The
|
||||
sections below discuss how `device` and `kernel` layer type names are made compatible across the
|
||||
two API versions, and what the users can expect out of the `threadblock` layer API going forward.
|
||||
@ -126,7 +126,7 @@ a 2.x mainloop with a 3.0 collective epilogue.
|
||||
CUTLASS 3.x implements various embodiments of `kernel::GemmUniversal`.
|
||||
Each kernel layer schedule is specialized
|
||||
for a GEMM scheduling algorithm and GPU architecture.
|
||||
Specializations of `kernel::GemmUniversal` for 3.0 APIs live in
|
||||
Specializations of `kernel::GemmUniversal` for 3.0 APIs live in
|
||||
any of various `gemm_*.hpp` files in the directory
|
||||
[include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/).
|
||||
The specialization to which to dispatch is decided through the dispatch policy's `Schedule` type.
|
||||
@ -155,7 +155,7 @@ All CUTLASS 3 `kernel::GemmUniversal` specializations expose the following (stat
|
||||
static bool
|
||||
can_implement(Arguments const& args);
|
||||
|
||||
// Returns a dim3 representing the threadblock shape.
|
||||
// Returns a dim3 representing the threadblock shape.
|
||||
static dim3
|
||||
get_block_shape();
|
||||
|
||||
@ -172,7 +172,7 @@ the 3.x API or 2.x API:
|
||||
// include/cutlass/gemm/gemm.h
|
||||
|
||||
namespace cutlass:gemm::detail {
|
||||
|
||||
|
||||
// The following metafunction is used to detect whether a
|
||||
// `kernel::Gemm` or `kernel::GemmUniversal` implements the CUTLASS 3.x API,
|
||||
// by checking whether the problem shape type is aliased within.
|
||||
@ -193,7 +193,7 @@ from that of CUTLASS 2.x. With that also comes the introduction of the
|
||||
of the 2.x `cutlass::gemm::threadblock` layer. Going forward,
|
||||
CUTLASS 3.x will discontinue new developments in the following namespaces.
|
||||
|
||||
* `cutlass::*::threadblock::*`
|
||||
* `cutlass::*::threadblock::*`
|
||||
* `cutlass::*::warp::*`
|
||||
* `cutlass::gemm::thread::*`
|
||||
* `cutlass::arch::*` (except `barrier.h`)
|
||||
@ -274,7 +274,7 @@ that live in the header file
|
||||
[`cutlass/layout/matrix.h`](/include/cutlass/layout/matrix.h).
|
||||
The interpretation of these layouts in GEMM
|
||||
depends on whether they are applied
|
||||
to the input matrix A or B. For the matrix A, "column major" means
|
||||
to the input matrix A or B. For the matrix A, "column major" means
|
||||
that mode corresponding to M extent has stride 1,
|
||||
and "row major" means that mode corresponding to K extent has stride 1.
|
||||
This is the usual computer science definition
|
||||
@ -332,7 +332,7 @@ and K mode as the 1st mode of the stride.
|
||||
### Conversions between 2.x tags and 3.0 types
|
||||
|
||||
Starting with CUTLASS 3.0, all layouts are described using
|
||||
`cute::Shape` and `cute::Stride` which compose into a `cute::Layout<Shape, Stride>`.
|
||||
`cute::Shape` and `cute::Stride` which compose into a `cute::Layout<Shape, Stride>`.
|
||||
In CUTLASS 2.x, various layout tags such as `cutlass::layout::RowMajor` are used to specialize
|
||||
template implementations. These tag types only encode information about the tensor strides,
|
||||
as 2.x layouts did not incorporate any concept of tensor shape in the layout tags themselves.
|
||||
@ -415,18 +415,18 @@ Here is an excerpt.
|
||||
static int const kThreadCount = GemmKernel::MaxThreadsPerBlock;
|
||||
|
||||
// Warp shape is not a primary API type in 3.x,
|
||||
// but we can best approximate it by inspecting the TiledMma::TiledShape_MNK.
|
||||
// but we can best approximate it by inspecting the TiledMma
|
||||
// For this, we make the assumption that we always have 4 warps along M,
|
||||
// and the rest along N, with none along K. We also always round up
|
||||
// the warp count to 4 if the tiled mma is smaller than 128 threads.
|
||||
static constexpr int WarpsInMma = std::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32);
|
||||
static constexpr int WarpsInMma = std::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32);
|
||||
static constexpr int WarpsInMmaM = 4;
|
||||
static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
|
||||
using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<
|
||||
cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM,
|
||||
cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN,
|
||||
cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>;
|
||||
CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM,
|
||||
CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN,
|
||||
CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>;
|
||||
|
||||
// Inspect TiledCopy for A and B to compute the alignment size
|
||||
static int constexpr kAlignmentA = gemm::detail::get_alignment_count_from_gmem_tiled_copy<
|
||||
@ -435,7 +435,7 @@ Here is an excerpt.
|
||||
typename CollectiveMainloop::GmemTiledCopyB, ElementB>();
|
||||
```
|
||||
|
||||
CUTLASS's library and profiler use these reflective interfaces to
|
||||
CUTLASS's library and profiler use these reflective interfaces to
|
||||
obtain the kernel's configuration parameters. Users can use these to approximate the CUTLASS 2.x types
|
||||
for 3.0 API kernels. However, the reflective interfaces cannot always match the types exactly,
|
||||
as the mappings are not always bijective.
|
||||
|
||||
Reference in New Issue
Block a user