v4.2 tag release. (#2638)
@@ -79,7 +79,7 @@ Instruction shape levels control the selection of WGMMA shapes used in kernel generation:

- **Level 2**: Includes shapes that are powers of 2.
- **Level 3**: Includes all other shapes.

The detailed definition of the three instantiation levels controlling cluster shape, MMA shape multiplier, and instruction shape can be found in [sm90_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py).

Schedule pruning levels decide the epilogue schedule and mainloop schedule used to stamp out a kernel instance. As defined in `get_valid_schedules` in [sm90_utils.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_utils.py),

@@ -122,6 +122,55 @@ For each mixed dtype kernel, the kernel generator will generate combinations of

For {4-bit-dtype, 8-bit-dtype} x 16-bit-dtype, the kernel generator will further generate kernels using shuffled layouts for the narrow-data-type matrix, which may perform better than their non-shuffled counterparts.

## Instantiating more kernels with Blackwell

Blackwell (SM100) and Blackwell Ultra similarly support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` for instantiating all possible combinations. Because generating and filtering these kernels alone can take hours, `CUTLASS_LIBRARY_KERNELS` must be non-empty. Exercise caution as well: not all of these configurations are tested, and some may fail to compile or fail to launch at runtime.

```bash
$ cmake .. \
  -DCUTLASS_NVCC_ARCHS="100f" \
  -DCUTLASS_LIBRARY_KERNELS="cutlass3x_sm100_tensorop_gemm_f16_f16_f32_void_f32_*" \
  -DCUTLASS_LIBRARY_INSTANTIATION_LEVEL="max" \
  -DCUTLASS_UNITY_BUILD_ENABLED=ON
```

The CUTLASS profiler uses the same four-digit integer level (global instantiation level) mechanism to manage the generation of kernel configurations for Blackwell as well:

0. **Instruction Shape**
1. **MMA Shape Multiplier**
2. **Cluster Shape**
3. **Data Type and Schedule Pruning**

Note that for Blackwell kernels the MMA shape multiplier is no longer necessary, since Blackwell kernels do not have separate ping-pong and cooperative schedules. The profiler ignores this digit when instantiating.

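As a minimal sketch (not the CUTLASS implementation), the four digits can be split out of the level as follows, assuming digit 0 above is the least-significant decimal digit:

```python
# Hypothetical helper, for illustration only: split a four-digit global
# instantiation level into its per-digit levels, assuming digit 0
# (instruction shape) is the least-significant decimal digit.
def split_instantiation_level(level: int) -> dict[str, int]:
    return {
        "instruction_shape": level % 10,                   # digit 0
        "mma_shape_multiplier": level // 10 % 10,          # digit 1 (ignored on Blackwell)
        "cluster_shape": level // 100 % 10,                # digit 2
        "dtype_and_schedule_pruning": level // 1000 % 10,  # digit 3
    }

print(split_instantiation_level(3210))
# {'instruction_shape': 0, 'mma_shape_multiplier': 1,
#  'cluster_shape': 2, 'dtype_and_schedule_pruning': 3}
```
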
Cluster shape levels define the number of CTAs (Cooperative Thread Arrays) included in kernel generation; the shapes are also collected as data in the sketch after this list:

- **Level 0**: Only dynamic cluster shapes.
- **Level 1**: `(1, 1, 1)` for 1SM kernels and `(2, 1, 1)` for 2SM kernels.
- **Level 2**: For 1SM kernels we also have `(1, 2, 1)`, and for 2SM we have `(2, 2, 1)` and `(4, 1, 1)`.
- **Level 3**: For 1SM kernels we have `(1, 4, 1)`, and for 2SM we have `(2, 4, 1)` and `(4, 2, 1)`.
- **Level 4**: For 1SM kernels we have `(4, 4, 1)`, and for 2SM we have `(4, 4, 1)`.
- **Level 5**: For 1SM kernels we have `(2, 1, 1)`.
- **Level 6**: For 1SM kernels we have `(2, 2, 1)` and `(4, 1, 1)`, and for 2SM kernels we have `(8, 1, 1)`.
- **Level 7**: For 1SM kernels we have `(2, 4, 1)` and `(4, 2, 1)`.
- **Level 8**: For 1SM kernels we have `(1, 8, 1)` and `(8, 1, 1)`.

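The same per-level shapes, summarized as plain data (a reference sketch only; the authoritative tables live in sm100_shapes.py, linked below):

```python
# Reference sketch: the cluster shapes listed above, as (x, y, z) CTA counts.
# This only mirrors the bullets; see sm100_shapes.py for the authoritative
# definitions. Level 0 selects only dynamic cluster shapes.
CLUSTER_SHAPES_1SM = {
    1: [(1, 1, 1)],
    2: [(1, 2, 1)],
    3: [(1, 4, 1)],
    4: [(4, 4, 1)],
    5: [(2, 1, 1)],
    6: [(2, 2, 1), (4, 1, 1)],
    7: [(2, 4, 1), (4, 2, 1)],
    8: [(1, 8, 1), (8, 1, 1)],
}
CLUSTER_SHAPES_2SM = {
    1: [(2, 1, 1)],
    2: [(2, 2, 1), (4, 1, 1)],
    3: [(2, 4, 1), (4, 2, 1)],
    4: [(4, 4, 1)],
    6: [(8, 1, 1)],
}
```
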
Instruction shape levels control the selection of MMA shapes used in kernel generation:

- **Level 0**: Generates the "default" shape only.
- **Level 1**: Includes additional shapes for FP8, FP6, and FP4 as well as MX and NVFP4.
- **Level 2**: Includes small tile shapes.
- **Level 3**: Includes some non-power-of-2 shapes.
- **Level 4**: Includes further small tile shapes and non-power-of-2 shapes.
- **Level 5**: Includes all shapes.

The detailed definition of the instantiation levels controlling cluster shape and instruction shape can be found in [sm100_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm100_shapes.py).

## CUTLASS Profiler usage

The CUTLASS Profiler usage statement may be obtained by executing `cutlass_profiler --help` and appears as follows.

@@ -577,6 +626,10 @@ cutlass3x_sm90_tensorop_gemm_f16_f16_f16_void_f16_128x128x64_1x1x1_0_nnn_align8_

* `f16_f16_f16_void_f16`: In this case, C type is set to `void`, indicating that residual matrix support is disabled (see the sketch below).

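As an illustrative sketch only (the field positions are inferred from the example name above, not from an official naming specification), the data-type segment of such a kernel name can be pulled apart by position:

```python
# Sketch: extract the data-type fields from a kernel name such as
# "cutlass3x_sm90_tensorop_gemm_f16_f16_f16_void_f16_...".
# The meaning assigned to each field here is an assumption for
# illustration; only "void in the C position disables residual support"
# comes from the text above.
name = "cutlass3x_sm90_tensorop_gemm_f16_f16_f16_void_f16_128x128x64_1x1x1_0_nnn_align8"
a_type, b_type, acc_type, c_type, d_type = name.split("_")[4:9]
print(c_type == "void")  # True: residual matrix support is disabled
```
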
## Further Documentation

For documentation on profiling blockwise and groupwise (software scaled) GEMMs, see the [example 81 README](https://github.com/NVIDIA/cutlass/blob/main/examples/81_blackwell_gemm_blockwise/README.md).

# Convolution

The CUTLASS Profiler is capable of executing 2-D and 3-D convolution problems for forwards and backwards

@@ -6,6 +6,7 @@ CuTe DSL API

.. toctree::
  :maxdepth: 1

  changelog <cute_dsl_api/changelog.rst>
  cute <cute_dsl_api/cute.rst>
  cute_arch <cute_dsl_api/cute_arch.rst>
  cute_nvgpu <cute_dsl_api/cute_nvgpu.rst>

media/docs/pythonDSL/cute_dsl_api/changelog.rst (new file, 54 lines)
@@ -0,0 +1,54 @@

======================================
Changelog for CuTe DSL API changes
======================================

`4.2.0 <https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0>`_ (2025-09-15)
==============================================================================

* Added back ``cute.make_tiled_copy`` per community request
* Added support for explicit and implicit broadcast in ``TensorSSA``

  - ``cutlass.cute.TensorSSA``: support ``broadcast_to`` and implicit broadcasting for binary operations.

* Supported printing ``TensorSSA`` values in ``cutlass.cute.print_tensor``
* Updated ``cute.gemm`` to support all dispatch patterns and improved checks for illegal inputs
* Introduced automatic kernel smem usage calculation for the launch config.
* Introduced per-op fast-math control for math ops (e.g. ``exp``, ``exp2``, ``log2``, ``log``)
* Introduced ``CopyReduceBulkTensorTileS2GOp`` in `tcgen05/copy.py <https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py>`_ to support TMA Reduce.

`4.1.0 <https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0>`_ (2025-07-16)
==============================================================================

* for loop

  - Python built-in ``range`` now always generates code and executes at runtime
  - ``cutlass.range`` is an advanced ``range`` with kernel-level unrolling and pipelining control
  - Deprecated ``cutlass.range_dynamic``; please replace it with ``range`` or ``cutlass.range``
  - **Experimental**: added ``pipelining`` control for compiler-generated software-pipeline code

* while/if

  - ``while``/``if`` now by default generate code and execute at runtime unless ``cutlass.const_expr`` is specified for the predicate
  - Deprecated ``cutlass.dynamic_expr``; please remove it (see the migration sketch below)

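A minimal migration sketch for the two deprecations above (assuming a JIT context in which ``n`` and ``flag`` are runtime values):

.. code-block:: python

   # Before: cutlass.range_dynamic(n) and cutlass.dynamic_expr(flag).
   # After: built-in range and plain predicates already generate runtime
   # code; use cutlass.range for unroll control and cutlass.const_expr
   # for compile-time evaluation of a predicate.
   for i in cutlass.range(n, unroll=2):  # runtime loop, unrolled by 2
       ...

   if cutlass.const_expr(flag):  # predicate evaluated at compile time
       ...
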
* Renamed mbarrier functions to reduce ambiguity
* Modified the SyncObject API (``MbarrierArray``, ``NamedBarrier``, ``TmaStoreFence``) to match ``std::barrier``
* Changed the pipeline ``create`` function to take only keyword arguments, and made ``barrier_storage`` optional.
* Introduced the ``cutlass.cute.arch.get_dyn_smem_size`` API to get the runtime dynamic shared memory size.
* Various API support for SM100 BlockScaled GEMM

  - Introduce BlockScaled MmaOps in `tcgen05/mma.py <https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py>`_, and provide a ``make_blockscaled_trivial_tiled_mma`` function in `blackwell_helpers.py <https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blackwell_helpers.py>`_ to help construct a BlockScaled TiledMma.
  - Introduce S2T CopyOps in `tcgen05/copy.py <https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py>`_.
  - Introduce BlockScaled layout utilities in `blockscaled_layout.py <https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blockscaled_layout.py>`_ for creating the required scale-factor layouts in global memory, shared memory, and tensor memory.

* ``cutlass.cute.compile`` now supports compilation options. Refer to `JIT compilation options <https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.html>`_ for more details.
* ``cutlass.cute.testing.assert_`` now works for device JIT functions. Specify the ``--enable-device-assertions`` compilation option to enable it.
* ``cutlass.cute.make_tiled_copy`` is now deprecated. Please use ``cutlass.cute.make_tiled_copy_tv`` instead.
* Shared memory capacity query

  - Introduce ``cutlass.utils.get_smem_capacity_in_bytes`` for querying the shared memory capacity.
  - ``<arch>_utils.SMEM_CAPACITY["<arch_str>"]`` is now deprecated.

`4.0.0 <https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0>`_ (2025-06-03)
==============================================================================

* Fixed API mismatch in class ``cute.runtime.Pointer``: changed ``element_type`` to ``dtype`` to match ``typing.Pointer``

@@ -72,6 +72,55 @@ All loop indices must be |Constexpr|.

   for i in cutlass.range(bound, unroll=2):
       cute.printf("%d\\n", i)

Software Pipelining
~~~~~~~~~~~~~~~~~~~

Software pipelining is a technique used to optimize loops. Typically, this involves writing a prefetch loop and a main loop.

.. code-block:: python

   @cute.jit
   def example():
       ...
       # build a circular buffer
       buffer = ...

       # prefetch loop
       for i in range(prefetch_stages):
           cute.copy(atom, gmem[i], buffer[i], ...)

       # main loop
       for i in range(bound):
           if i + prefetch_stages < bound:
               cute.copy(atom, gmem[i + prefetch_stages], buffer[(i + prefetch_stages) % total_stages], ...)

           use(buffer[i % total_stages])

       ...

This can be tedious to write and tune. |DSL| provides a loop attribute that asks the compiler to do this for you.

.. code-block:: python

   @cute.jit
   def example():
       ...
       # build a circular buffer
       buffer = ...

       for i in cutlass.range(bound, prefetch_stages=prefetch_stages):
           # Compiler automatically handles the pipelining:
           # - Generates prefetch loop for initial stages
           # - In main loop, prefetches future data while using current data
           cute.copy(atom, gmem[i], buffer[i % total_stages], ...)
           use(buffer[i % total_stages])  # Uses data from previous iterations

       ...

The compiler will automatically generate a prefetch loop with ``prefetch_stages`` iterations and a corresponding main loop.

This feature is experimental and only supported on sm90 and above.

If-Else Statements
------------------

@@ -7,7 +7,8 @@ Integration with Frameworks

In order to facilitate the integration of CUTLASS Python with popular frameworks, we leverage the `DLPack protocol <https://github.com/dmlc/dlpack>`_ and transform tensors originating from these frameworks to CuTe tensors. The present page documents the conventions, the API available to the user, and provides example code snippets for common usage patterns. We also provide a section on how to bypass the DLPack protocol and directly call the JIT function.

Implicit Conversion
-------------------

@@ -396,3 +397,84 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to

       mode=0, divisibility=1, stride_order=(2, 1, 3, 0, 4)
   )
   # The stride_order is not consistent with the layout

Bypass the DLPack Protocol
--------------------------

In certain scenarios, users may wish to bypass the DLPack protocol and invoke the JIT function directly. This can be accomplished by creating a lightweight JIT wrapper around the existing JIT function, utilizing ``cute.ptr`` and ``cute.make_tensor`` to pass pointers and construct tensors directly.

Typical use cases for bypassing DLPack include:

1. Users want to call the JIT function directly to avoid the overhead introduced by the DLPack protocol.
2. DLPack canonicalizes the stride of shape-1 dimensions to 1, which may result in incorrect alignment propagation and affect memory access or performance.
3. DLPack may lack support for some narrow data types.

The following example illustrates how to bypass the DLPack protocol when invoking a JIT function. Assume we have a pre-defined ``TensorOpGemm`` kernel whose JIT interface expects three arguments of type ``cute.Tensor``. To enable direct invocation without DLPack, we first define a JIT wrapper function that accepts ``cute.Pointer`` types as parameters. Within this wrapper, we use ``cute.make_tensor`` to construct tensors from the provided pointers, and then call the ``TensorOpGemm`` kernel as usual.

.. code-block:: python

   @cute.jit
   def tensor_op_gemm_wrapper(
       a_ptr: cute.Pointer,
       b_ptr: cute.Pointer,
       c_ptr: cute.Pointer,
       m: cutlass.Int32,
       n: cutlass.Int32,
       k: cutlass.Int32,
       l: cutlass.Int32,
   ):
       # Assume alignment of shape to call tensorop_gemm example
       m = cute.assume(m, divby=8)
       n = cute.assume(n, divby=8)

       # Torch is row major
       a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2))
       b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2))
       c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2))
       mA = cute.make_tensor(a_ptr, layout=a_layout)
       mB = cute.make_tensor(b_ptr, layout=b_layout)
       mC = cute.make_tensor(c_ptr, layout=c_layout)

       # TensorOpGemm is a pre-defined kernel from our example
       tensor_op_gemm = TensorOpGemm(
           a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1)
       )

       tensor_op_gemm(mA, mB, mC)

To pass a PyTorch tensor to this new JIT wrapper, we retrieve the raw pointer from the PyTorch tensor and create a ``cute.Pointer`` instance using ``cute.make_ptr``. This approach allows us to bypass the DLPack protocol entirely, avoiding its overhead and potential issues with shape-1 dimension handling.

.. code-block:: python

   a = torch.randn(
       m, k, l, dtype=torch.float16, device="cuda"
   ).permute(2, 1, 0)
   b = torch.randn(
       n, k, l, dtype=torch.float16, device="cuda"
   ).permute(2, 1, 0)
   c = torch.randn(
       n, m, l, dtype=torch.float16, device="cuda"
   ).permute(1, 2, 0)

   # from cutlass.cute.runtime import make_ptr
   a_ptr = make_ptr(
       cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
   )
   b_ptr = make_ptr(
       cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
   )
   c_ptr = make_ptr(
       cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
   )
   tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, m, n, k, l)