v4.1 release

This commit is contained in:
Junkai-Wu
2025-07-03 20:07:53 +08:00
committed by GitHub
parent b995f93317
commit a1aaf2300a
155 changed files with 18407 additions and 6068 deletions

View File

@ -29,14 +29,15 @@
import argparse
import operator
import torch
from typing import Type
import time
from typing import Type, List
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
@ -77,8 +78,7 @@ while maintaining high performance through efficient memory access patterns.
@cute.kernel
def elementwise_apply_kernel(
op: cutlass.Constexpr,
gA: cute.Tensor,
gB: cute.Tensor,
inputs: List[cute.Tensor],
gC: cute.Tensor,
cC: cute.Tensor, # coordinate tensor
shape: cute.Shape,
@ -90,48 +90,46 @@ def elementwise_apply_kernel(
# slice for CTAs
cta_coord = ((None, None), bidx)
# logical coord -> address
ctaA = gA[cta_coord] # (TileM, TileN)
ctaB = gB[cta_coord] # (TileM, TileN)
# Leverage the meta-programming capability of the DSL to slice the tensors for each input
# All for loops below on input tensors would be fully unrolled automatically at compile time
ctaInputs = [t[cta_coord] for t in inputs] # (TileM, TileN)
ctaC = gC[cta_coord] # (TileM, TileN)
ctaCrd = cC[cta_coord] # (TileM, TileN)
print(f"[DSL INFO] Sliced Tensors per thread block:")
print(f"[DSL INFO] ctaA = {ctaA.type}")
print(f"[DSL INFO] ctaB = {ctaB.type}")
for i in cutlass.range_constexpr(len(ctaInputs)):
print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}")
print(f"[DSL INFO] ctaC = {ctaC.type}")
print(f"[DSL INFO] ctaCrd = {ctaCrd.type}")
# compose with CTA TV layout
# (tid, vid) -> address
tidfrgA = cute.composition(ctaA, tv_layout)
tidfrgB = cute.composition(ctaB, tv_layout)
tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs]
tidfrgC = cute.composition(ctaC, tv_layout)
tidfrgCrd = cute.composition(ctaCrd, tv_layout)
# print(f"{tv_layout = }")
# print(f"{tidfrgA = }")
# print(f"{tidfrgAB[0] = }")
thr_coord = (tidx, (None, None))
# slice for threads
# vid -> address
thrA = tidfrgA[thr_coord] # (V)
thrB = tidfrgB[thr_coord] # (V)
thrInputs = [t[thr_coord] for t in tidfrgInputs] # (V)
thrC = tidfrgC[thr_coord] # (V)
thrCrd = tidfrgCrd[thr_coord]
print(f"[DSL INFO] Sliced Tensors per thread:")
print(f"[DSL INFO] thrA = {thrA.type}")
print(f"[DSL INFO] thrB = {thrB.type}")
for i in cutlass.range_constexpr(len(thrInputs)):
print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}")
print(f"[DSL INFO] thrC = {thrC.type}")
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
# allocate fragments for gmem->rmem
frgA = cute.make_fragment_like(thrA, gA.element_type)
frgB = cute.make_fragment_like(thrB, gB.element_type)
frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs]
frgC = cute.make_fragment_like(thrC, gC.element_type)
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
for i in cutlass.range_dynamic(cute.size(frgPred), unroll=1):
for i in cutlass.range(cute.size(frgPred), unroll=1):
frgPred[i] = cute.elem_less(thrCrd[i], shape)
# if tidx == 0 and bidx == 0:
@ -142,10 +140,13 @@ def elementwise_apply_kernel(
##########################################################
# declare the atoms which will be used later for memory copy
# Compile time validation: expect same element type for all input tensors so as to reuse the copy atom for load
assert all(t.element_type == inputs[0].element_type for t in inputs)
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
gA.element_type,
num_bits_per_copy=gA.element_type.width,
inputs[0].element_type,
num_bits_per_copy=inputs[0].element_type.width,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
@ -153,12 +154,12 @@ def elementwise_apply_kernel(
num_bits_per_copy=gC.element_type.width,
)
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
for thrInput, frgInput in zip(thrInputs, frgInputs):
cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred)
# Load data before use. The compiler will optimize the copy and load
# operations to convert some memory ld/st into register uses.
result = op(frgA.load(), frgB.load())
result = op(*[frgInput.load() for frgInput in frgInputs])
# Save the results back to registers. Here we reuse b's registers.
frgC.store(result)
@ -173,6 +174,7 @@ def elementwise_apply(
a: cute.Tensor,
b: cute.Tensor,
result: cute.Tensor,
stream: cuda.CUstream,
):
"""CUDA kernel applying binary operator on each element of two n-D input tensors in
CuTe Python and store to result tensor.
@ -262,8 +264,7 @@ def elementwise_apply(
# Async token(s) can also be specified as dependencies
elementwise_apply_kernel(
op,
gA,
gB,
[gA, gB], # Group input tensors into a list as a single argument
gC,
cC,
result.shape,
@ -271,6 +272,7 @@ def elementwise_apply(
).launch(
grid=[cute.size(gC, mode=[1]), 1, 1],
block=[cute.size(tv_layout, mode=[0]), 1, 1],
stream=stream,
)
@ -287,6 +289,11 @@ def run_elementwise_apply_and_verify(
if not torch.cuda.is_available():
raise RuntimeError(f"Ampere GPU is required to run this example!")
# Create non default CUDA stream from PyTorch
torch_stream = torch.cuda.Stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
print(f"\nRunning Elementwise Apply test with:")
print(f"Tensor dimensions: [{M}, {N}]")
print(f"Input and Output Data type: {dtype}")
@ -309,20 +316,16 @@ def run_elementwise_apply_and_verify(
if op in (operator.truediv, operator.floordiv):
b = torch.where(b == 0, torch.tensor(epsilon), b)
print("Compiling kernel with cute.compile ...")
start_time = time.time()
compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
print("Executing elementwise apply kernel...")
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
if not skip_ref_check:
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
elementwise_apply(
op,
from_dlpack(a),
from_dlpack(b),
from_dlpack(c).mark_layout_dynamic(),
current_stream,
)
print("Verifying results...")
torch.testing.assert_close(op(a, b), c)
print("Results verified successfully!")
@ -330,28 +333,32 @@ def run_elementwise_apply_and_verify(
if not benchmark:
return
# Create CUDA events for timing
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
compiled_func = cute.compile(
elementwise_apply,
op,
from_dlpack(a),
from_dlpack(b),
from_dlpack(c).mark_layout_dynamic(),
current_stream,
)
# Warmup
for _ in range(warmup_iterations):
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
# When compiled we inlined op in the kernel, so we do not pass it when benchmarking
# Record start event
cuda.cuEventRecord(start_event, current_stream)
avg_time_us = testing.benchmark(
compiled_func,
kernel_arguments=testing.JitArguments(
from_dlpack(a),
from_dlpack(b),
from_dlpack(c).mark_layout_dynamic(),
current_stream,
),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
use_cuda_graphs=True,
stream=current_stream,
)
# Execute the kernel
for _ in range(iterations):
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
# Record end event
cuda.cuEventRecord(end_event, current_stream)
cuda.cuEventSynchronize(end_event)
# Calculate elapsed time
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
avg_time = elapsed_time / iterations
avg_time = avg_time_us / 1e3
# Print execution results
print(f"Kernel execution time: {avg_time:.4f} ms")
@ -360,10 +367,6 @@ def run_elementwise_apply_and_verify(
)
print(f"First few elements of result: \n{c[:3, :3]}")
# Destroy events
cuda.cuEventDestroy(start_event)
cuda.cuEventDestroy(end_event)
if __name__ == "__main__":
parser = argparse.ArgumentParser(