# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import operator
import time
from typing import Type, List

import cuda.bindings.driver as cuda
import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack

"""
|
|
An Elementwise Apply Example using CuTe DSL.
|
|
|
|
This example kernel demonstrates the meta-programming capability of the CuTe DSL by allowing
|
|
customization of elementwise operations through lambda functions. The kernel copies data from
|
|
global memory to register memory (rmem), applies a user-defined operation to the elements,
|
|
and stores the result back to global memory.
|
|
|
|
Primary goals of this example:
|
|
1. Demonstrate meta-programming capability by passing lambda functions to customize elementwise operations
|
|
2. Show how to apply different operations (add, multiply, etc.) using the same kernel structure
|
|
3. Illustrate how to parameterize CUDA kernels with operation types at compile time
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
# Run with addition operation
|
|
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op add
|
|
|
|
# Run with multiplication operation
|
|
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op mul
|
|
|
|
# Run with subtraction operation
|
|
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op sub
|
|
|
|
# Benchmark performance
|
|
python examples/ampere/elementwise_apply.py --M 2048 --N 2048 --op add --benchmark --warmup_iterations 2 --iterations 10
|
|
|
|
The example demonstrates how to express complex CUDA kernels with customizable operations
|
|
while maintaining high performance through efficient memory access patterns.
|
|
"""
|
|
|
|
|
|
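
# Structure of the code below:
#   * zipped_divide splits each global tensor into ((TileM, TileN), (RestM, RestN))
#   * each CTA (blockIdx.x) picks one (TileM, TileN) tile
#   * the tile is composed with a TV layout so every thread owns a small value fragment
#   * the kernel copies gmem -> rmem, applies `op`, and copies rmem -> gmem,
#     predicated on element coordinates so partial boundary tiles stay in bounds
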
@cute.kernel
def elementwise_apply_kernel(
    op: cutlass.Constexpr,
    inputs: List[cute.Tensor],
    gC: cute.Tensor,
    cC: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
    tv_layout: cute.Layout,  # (tid, vid) -> logical coord
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # slice for CTAs
    cta_coord = ((None, None), bidx)
    # logical coord -> address
    # Leverage the meta-programming capability of the DSL to slice the tensors for each input.
    # The for loops over the list of input tensors below are fully unrolled at compile time.
    ctaInputs = [t[cta_coord] for t in inputs]  # (TileM, TileN)
    ctaC = gC[cta_coord]  # (TileM, TileN)
    ctaCrd = cC[cta_coord]  # (TileM, TileN)

    print(f"[DSL INFO] Sliced Tensors per thread block:")
    for i in cutlass.range_constexpr(len(ctaInputs)):
        print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}")
    print(f"[DSL INFO] ctaC = {ctaC.type}")
    print(f"[DSL INFO] ctaCrd = {ctaCrd.type}")

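    # For illustration (assuming the default layouts set up in `elementwise_apply`,
    # i.e. a (16, 128) tile, and a 1024x512 input): each input arrives here with
    # shape ((16, 128), (64, 4)), and ctaInputs[i] / ctaC / ctaCrd above are the
    # single (16, 128) tile owned by this CTA.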
    # compose with CTA TV layout
    # (tid, vid) -> address
    tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs]
    tidfrgC = cute.composition(ctaC, tv_layout)
    tidfrgCrd = cute.composition(ctaCrd, tv_layout)
    # print(f"{tv_layout = }")
    # print(f"{tidfrgInputs[0] = }")

    thr_coord = (tidx, (None, None))

    # slice for threads
    # vid -> address
    thrInputs = [t[thr_coord] for t in tidfrgInputs]  # (V)
    thrC = tidfrgC[thr_coord]  # (V)
    thrCrd = tidfrgCrd[thr_coord]

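    # With the default (4, 32) thread and (4, 4) value layouts, each thread owns a
    # (4, 4) fragment, i.e. 16 of the 2048 elements in the (16, 128) tile
    # (2048 elements / 128 threads).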
print(f"[DSL INFO] Sliced Tensors per thread:")
|
|
for i in cutlass.range_constexpr(len(thrInputs)):
|
|
print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}")
|
|
print(f"[DSL INFO] thrC = {thrC.type}")
|
|
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
|
|
|
|
    # allocate fragments for gmem->rmem
    frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs]
    frgC = cute.make_fragment_like(thrC, gC.element_type)
    frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)

    for i in cutlass.range(cute.size(frgPred), unroll=1):
        frgPred[i] = cute.elem_less(thrCrd[i], shape)

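    # frgPred marks which of this thread's elements have coordinates inside `shape`.
    # It is passed as `pred` to the copies below so partial tiles at the tensor
    # boundary neither read nor write out of bounds.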
    # if tidx == 0 and bidx == 0:
    #     cute.print_tensor(frgPred)

    ##########################################################
    # Move data to reg address space
    ##########################################################

    # Declare the copy atoms used for the memory copies below.
    # Compile-time validation: all input tensors must share an element type so the
    # same copy atom can be reused for every load.
    assert all(t.element_type == inputs[0].element_type for t in inputs)

    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        inputs[0].element_type,
        num_bits_per_copy=inputs[0].element_type.width,
    )
    copy_atom_store = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        gC.element_type,
        num_bits_per_copy=gC.element_type.width,
    )

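    # CopyUniversalOp is a generic copy; num_bits_per_copy is set to one element's
    # width here, and a single load atom is shared by all inputs (hence the element
    # type assert above). The DSL may still vectorize the copy when the value
    # layout allows it (see the layout notes in `elementwise_apply`).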
    for thrInput, frgInput in zip(thrInputs, frgInputs):
        cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred)

    # Load the fragments before use. The compiler optimizes the copy and load
    # operations so that some memory loads/stores become plain register uses.
    result = op(*[frgInput.load() for frgInput in frgInputs])

    # Store the result into the output register fragment.
    frgC.store(result)

    # Copy the result back to the output tensor in global memory.
    cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)

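
# Host-side entry point. Note that `op` is a cutlass.Constexpr: it is fixed at
# compile time, so each distinct operator (or lambda) yields its own specialized
# kernel (see the cute.compile call in run_elementwise_apply_and_verify).
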
@cute.jit
def elementwise_apply(
    op: cutlass.Constexpr,
    a: cute.Tensor,
    b: cute.Tensor,
    result: cute.Tensor,
    stream: cuda.CUstream,
):
"""CUDA kernel applying binary operator on each element of two n-D input tensors in
|
|
CuTe Python and store to result tensor.
|
|
|
|
:param op: Binary operator or lambda function to apply element-wise
|
|
:type op: cutlass.Constexpr
|
|
:param a: First input tensor
|
|
:type a: cute.Tensor
|
|
:param b: Second input tensor
|
|
:type b: cute.Tensor
|
|
:param result: Output tensor to store the results of op(a, b)
|
|
:type result: cute.Tensor
|
|
:return: None
|
|
:rtype: None
|
|
|
|
.. code-block:: python
|
|
|
|
# Example 1: Adding two tensors
|
|
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, device="cuda")
|
|
y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32, device="cuda")
|
|
result = torch.empty_like(x)
|
|
elementwise_apply(operator.add, from_dlpack(x), from_dlpack(y), from_dlpack(result))
|
|
# result:
|
|
# tensor([[6.0, 8.0],
|
|
# [10.0, 12.0]], device='cuda:0')
|
|
|
|
# Example 2: Using a lambda function
|
|
elementwise_apply(lambda a, b: a * a + b * b, from_dlpack(x), from_dlpack(y), from_dlpack(result))
|
|
# result:
|
|
# tensor([[ 2., 8.],
|
|
# [ 54., 512.]], device='cuda:0')
|
|
"""
|
|
|
|
    # Baseline: naive TV layout
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout maps to a (512, 4) tile
    # * tidx maps to mode-0 while the input layout is contiguous on mode-1, so
    #   accesses are not coalesced and performance is poor
    # tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1)))
    # cta_tiler = (512, 4)

    # Opt-1: better TV layout with a better 1D thread layout (SOL with a 1D thread layout)
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout maps to a (4, 512) tile
    # * tidx maps to mode-1, the leading mode of the input tensor, giving coalesced loads
    # tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1)))
    # cta_tiler = (4, 512)

    # Opt-2: 2D tile, but worse
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout maps to a (128, 16) logical tile
    # * the V layout is bad because its contiguous mode is not right-most
    #   (`cute.copy` can only vectorize when the stride-1 mode of the V layout is right-most)
    # tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128)))
    # cta_tiler = (128, 16)

    # Opt-3: SOL with a 2D thread tile
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout maps to a (16, 128) logical tile
    # * tidx maps to mode-1 and the input layout is contiguous on mode-1, giving coalesced loads and stores
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))
    val_layout = cute.make_layout((4, 4), stride=(4, 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

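    # With thr_layout (4, 32) and val_layout (4, 4), tiler_mn is a (16, 128) tile
    # handled by 128 threads with 16 elements each; consecutive thread indices walk
    # the contiguous mode-1, which is what makes the Opt-3 accesses coalesced.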
print(f"[DSL INFO] Input Tensors:")
|
|
print(f"[DSL INFO] a = {a.type}")
|
|
print(f"[DSL INFO] b = {b.type}")
|
|
print(f"[DSL INFO] result = {result.type}")
|
|
|
|
print(f"[DSL INFO] Tiling Parameters:")
|
|
print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
|
|
print(f"[DSL INFO] tv_layout = {tv_layout}")
|
|
|
|
    gA = cute.zipped_divide(a, tiler_mn)  # ((TileM, TileN), (RestM, RestN))
    gB = cute.zipped_divide(b, tiler_mn)  # ((TileM, TileN), (RestM, RestN))
    gC = cute.zipped_divide(result, tiler_mn)  # ((TileM, TileN), (RestM, RestN))

    print(f"[DSL INFO] Tiled Tensors:")
    print(f"[DSL INFO] gA = {gA.type}")
    print(f"[DSL INFO] gB = {gB.type}")
    print(f"[DSL INFO] gC = {gC.type}")

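    # make_identity_tensor produces a tensor whose value at each position is that
    # position's (row, col) coordinate. Tiling it exactly like the data lets each
    # thread compare its own global coordinates against `shape` for predication.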
    idC = cute.make_identity_tensor(result.shape)
    cC = cute.zipped_divide(idC, tiler=tiler_mn)
    print(f"[DSL INFO] coord tensor = {cC.type}")

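    # One CTA per (TileM, TileN) tile: the grid size is the size of gC's Rest mode,
    # and the block size is the thread mode of the TV layout (128 threads with the
    # layouts above).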
    # Launch the kernel asynchronously
    # Async token(s) can also be specified as dependencies
    elementwise_apply_kernel(
        op,
        [gA, gB],  # Group input tensors into a list as a single argument
        gC,
        cC,
        result.shape,
        tv_layout,
    ).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
        stream=stream,
    )


def run_elementwise_apply_and_verify(
    op,
    M,
    N,
    dtype: Type[cutlass.Numeric],
    skip_ref_check=False,
    benchmark=True,
    warmup_iterations=2,
    iterations=100,
):
    if not torch.cuda.is_available():
        raise RuntimeError("Ampere GPU is required to run this example!")

    # Create a non-default CUDA stream from PyTorch
    torch_stream = torch.cuda.Stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

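    # The same stream is passed to the JIT call, cute.compile, and the benchmark
    # helper below so that all work is ordered on this non-default stream.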
print(f"\nRunning Elementwise Apply test with:")
|
|
print(f"Tensor dimensions: [{M}, {N}]")
|
|
print(f"Input and Output Data type: {dtype}")
|
|
print(f"Warmup iterations: {warmup_iterations}")
|
|
print(f"Measurement iterations: {iterations}\n")
|
|
|
|
torch_dtype = cutlass_torch.dtype(dtype)
|
|
|
|
    # Allocate tensors with random values.
    a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
    b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
    c = torch.zeros_like(a)

    print(f"Input tensor shapes:")
    print(f"a: {a.shape}, dtype: {a.dtype}")
    print(f"b: {b.shape}, dtype: {b.dtype}")
    print(f"c: {c.shape}, dtype: {c.dtype}\n")

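    # For division operators, replace exact zeros in b so neither the kernel nor
    # the reference computation divides by zero.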
    epsilon = 1.2
    if op in (operator.truediv, operator.floordiv):
        b = torch.where(b == 0, torch.tensor(epsilon), b)

    print("Executing elementwise apply kernel...")

    if not skip_ref_check:
        elementwise_apply(
            op,
            from_dlpack(a),
            from_dlpack(b),
            from_dlpack(c).mark_layout_dynamic(),
            current_stream,
        )
        print("Verifying results...")
        torch.testing.assert_close(op(a, b), c)
        print("Results verified successfully!")

    if not benchmark:
        return

    compiled_func = cute.compile(
        elementwise_apply,
        op,
        from_dlpack(a),
        from_dlpack(b),
        from_dlpack(c).mark_layout_dynamic(),
        current_stream,
    )

    # `op` is a Constexpr and is baked into the compiled kernel, so it is not
    # passed again when benchmarking.
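    # The compiled function can also be invoked directly with the same runtime
    # arguments used below, e.g. (sketch, not exercised here):
    #   compiled_func(from_dlpack(a), from_dlpack(b),
    #                 from_dlpack(c).mark_layout_dynamic(), current_stream)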
    avg_time_us = testing.benchmark(
        compiled_func,
        kernel_arguments=testing.JitArguments(
            from_dlpack(a),
            from_dlpack(b),
            from_dlpack(c).mark_layout_dynamic(),
            current_stream,
        ),
        warmup_iterations=warmup_iterations,
        iterations=iterations,
        use_cuda_graphs=True,
        stream=current_stream,
    )

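    # testing.benchmark reports microseconds; convert to milliseconds. The
    # throughput estimate below counts two input reads plus one output write per
    # element, hence the factor of 3.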
    avg_time = avg_time_us / 1e3

    # Print execution results
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(
        f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s"
    )
    print(f"First few elements of result: \n{c[:3, :3]}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="example of elementwise apply to demonstrate building elementwise kernels"
    )
    parser.add_argument("--M", default=128, type=int)
    parser.add_argument("--N", default=128, type=int)
    parser.add_argument("--op", default="add", type=str)
    parser.add_argument("--warmup_iterations", default=2, type=int)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument("--skip_ref_check", action="store_true")
    parser.add_argument("--benchmark", action="store_true")
    args = parser.parse_args()

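    # `--op` is resolved with getattr(operator, args.op), so any binary function
    # from Python's operator module (add, mul, sub, truediv, ...) can be used.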
    run_elementwise_apply_and_verify(
        getattr(operator, args.op),
        args.M,
        args.N,
        dtype=cutlass.Float32,
        warmup_iterations=args.warmup_iterations,
        iterations=args.iterations,
        skip_ref_check=args.skip_ref_check,
        benchmark=args.benchmark,
    )
    print("\nPASS")