v4.2 tag release. (#2638)
This commit is contained in:
314
examples/python/CuTeDSL/ampere/all_reduce.py
Normal file
314
examples/python/CuTeDSL/ampere/all_reduce.py
Normal file
@ -0,0 +1,314 @@
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
from cuda import cuda
|
||||
from cuda.bindings import driver
|
||||
from typing import Type
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass._mlir.dialects import llvm, builtin, vector, arith
|
||||
|
||||
WORLD_SIZE = 8
|
||||
PING_PONG_SIZE = 3
|
||||
|
||||
|
||||
def setup(rank, world_size):
    """Join the NCCL process group as `rank` and select its CUDA device."""
    # Rendezvous endpoint; every rank must use the same address/port.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12959"})
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # All subsequent CUDA work in this process targets device `rank`.
    torch.cuda.set_device(rank)
|
||||
|
||||
|
||||
def cleanup():
    """Tear down the process group created by setup()."""
    dist.destroy_process_group()
|
||||
|
||||
|
||||
class AllReduceKernel:
    """One-shot all-reduce over symmetric-memory buffers.

    Each rank broadcasts its input tile into the `rank` slot of every peer's
    buffer, then spin-reads all WORLD_SIZE slots of its own buffer and sums
    them into the local output.  A bitwise -0.0 (0x80000000) value is used as
    the "not yet written" sentinel, and a ping/pong index derived from
    `signal` selects which of the PING_PONG_SIZE buffer slices is active.

    NOTE(review): this class uses `T` (T.vector / T.f32 / T.i32), which does
    not appear among the file's visible imports — confirm the MLIR types
    helper is imported in the full file.
    """

    @cute.jit
    def __call__(
        self,
        rank,
        signal,
        local_input: cute.Tensor,
        local_output: cute.Tensor,
        buffer0: cute.Tensor,
        buffer1: cute.Tensor,
        buffer2: cute.Tensor,
        buffer3: cute.Tensor,
        buffer4: cute.Tensor,
        buffer5: cute.Tensor,
        buffer6: cute.Tensor,
        buffer7: cute.Tensor,
        stream: cuda.CUstream,
    ):
        """Tile the tensors and launch the all-reduce kernel on `stream`.

        :param rank: this process's rank (selects its slot in peers' buffers)
        :param signal: monotonically increasing call counter; drives ping/pong
        :param local_input: flat input tensor on this rank's device
        :param local_output: flat output tensor receiving the reduced values
        :param buffer0..buffer7: one symmetric-memory buffer view per rank
        :param stream: CUDA stream to launch on
        """
        # define constants for future use
        num_of_elements = cute.size(local_input.layout)
        # 128 threads per block and 4 elements per thread
        tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
        tile = cute.size(tv_layout.shape)

        buffers = [
            buffer0,
            buffer1,
            buffer2,
            buffer3,
            buffer4,
            buffer5,
            buffer6,
            buffer7,
        ]
        # Split only the element mode into tiles; keep the (rank, ping/pong)
        # modes of each buffer intact.
        tiled_buffers = [
            cute.logical_divide(buffer, (tile, None, None)) for buffer in buffers
        ]

        tiled_input = cute.zipped_divide(local_input, cute.make_layout(tile))
        tiled_output = cute.zipped_divide(local_output, cute.make_layout(tile))
        self.kernel(
            tiled_buffers[0],
            tiled_buffers[1],
            tiled_buffers[2],
            tiled_buffers[3],
            tiled_buffers[4],
            tiled_buffers[5],
            tiled_buffers[6],
            tiled_buffers[7],
            tiled_input,
            tiled_output,
            tv_layout,
            cutlass.Int32(signal),
            cutlass.Int32(rank),
        ).launch(
            grid=[num_of_elements // tile, 1, 1],
            block=[tv_layout.shape[0], 1, 1],
            stream=stream,
        )

    # GPU device kernel
    @cute.kernel
    def kernel(
        self,
        buffer0: cute.Tensor,
        buffer1: cute.Tensor,
        buffer2: cute.Tensor,
        buffer3: cute.Tensor,
        buffer4: cute.Tensor,
        buffer5: cute.Tensor,
        buffer6: cute.Tensor,
        buffer7: cute.Tensor,
        local_input: cute.Tensor,
        local_output: cute.Tensor,
        tv_layout: cute.Layout,
        signal: cutlass.Int32,
        rank: cutlass.Int32,
    ):
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        # Ping/pong slice selection; 3 == PING_PONG_SIZE.
        ping = signal % 3
        pong = (signal + 1) % 3

        buffers = [
            buffer0,
            buffer1,
            buffer2,
            buffer3,
            buffer4,
            buffer5,
            buffer6,
            buffer7,
        ]

        def get_buffer():
            # Select this rank's own buffer; `rank` is a runtime value, so an
            # if-chain (not Python indexing) is required inside the kernel.
            t = buffers[2]
            if rank == cutlass.Int32(0):
                t = buffers[0]
            if rank == cutlass.Int32(1):
                t = buffers[1]
            if rank == cutlass.Int32(2):
                t = buffers[2]
            if rank == cutlass.Int32(3):
                t = buffers[3]
            if rank == cutlass.Int32(4):
                t = buffers[4]
            if rank == cutlass.Int32(5):
                t = buffers[5]
            if rank == cutlass.Int32(6):
                t = buffers[6]
            if rank == cutlass.Int32(7):
                t = buffers[7]
            return t

        buffer_local = get_buffer()
        cta_coord = (None, bidx)
        local_tile_in = local_input[cta_coord]
        local_tile_out = local_output[cta_coord]

        # All rank slots of this CTA's tile in the active (ping) slice.
        ping_coord = ((None, bidx), None, ping)
        read_buffer = buffer_local[ping_coord]

        # The next (pong) slice is reset to the sentinel for the next call.
        pong_coord = ((None, bidx), None, pong)
        clear_buffer = buffer_local[pong_coord]

        # This rank's slot in every peer's active slice.
        write_coord = ((None, bidx), rank, ping)
        write_buffers = [buffer[write_coord] for buffer in buffers]

        # assume all buffers have the same element type with input
        copy_atom_load = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            buffer0.element_type,
            num_bits_per_copy=64,
            memory_scope=cute.nvgpu.common.MemoryScope.SYS,
            memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        )
        copy_atom_store = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            buffer0.element_type,
            num_bits_per_copy=64,
            memory_scope=cute.nvgpu.common.MemoryScope.SYS,
            memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        )
        tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
        thr_copy = tiled_copy.get_slice(tidx)

        thr_write_buffer_list = [
            thr_copy.partition_D(tensor) for tensor in write_buffers
        ]
        thr_read_buffer = thr_copy.partition_S(read_buffer)
        thr_clear_buffer = thr_copy.partition_D(clear_buffer)

        thr_in = thr_copy.partition_S(local_tile_in)
        thr_out = thr_copy.partition_D(local_tile_out)

        frg_in = cute.make_fragment_like(thr_in)
        frg_clear = cute.make_fragment_like(thr_clear_buffer)
        frg_acc = cute.make_fragment_like(thr_out)
        frg_acc.fill(0.0)

        # Build a fragment full of bitwise -0.0 (0x80000000 reinterpreted as
        # f32) — the "slot not yet written" sentinel.
        clear_tensor = frg_clear.load()
        frg_size = cute.size(clear_tensor.shape)
        neg0_i32_vec = cute.full_like(clear_tensor, 0x80000000, cutlass.Int32)
        neg0_f32_vec = vector.bitcast(T.vector(frg_size, T.f32()), neg0_i32_vec)
        neg0_f32_tensor = cute.TensorSSA(
            neg0_f32_vec, clear_tensor.shape, cutlass.Float32
        )
        frg_clear.store(neg0_f32_tensor)

        cute.copy(copy_atom_load, thr_in, frg_in)

        # Broadcast this rank's input into slot `rank` of every peer.
        for thr_write_buffer in thr_write_buffer_list:
            cute.copy(copy_atom_store, frg_in, thr_write_buffer)

        # Reset the pong slice so the next call starts from the sentinel.
        cute.copy(copy_atom_store, frg_clear, thr_clear_buffer)

        frg_in_vector_neg0_i32 = cute.full_like(
            frg_in, cutlass.Int32(0x80000000), cutlass.Int32
        )
        frg_in_size = cute.size(frg_in.shape)

        for i in range(WORLD_SIZE):
            read_coord = (None, 0, i)
            cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
            frg_vector = frg_in.load()
            frg_vector_i32 = vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector)
            # Spin until every lane of slot i differs bitwise from -0.0,
            # i.e. peer i has fully written its contribution.
            isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
            while not isNotNeg0:
                cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
                frg_vector = frg_in.load()
                frg_vector_i32 = vector.bitcast(
                    T.vector(frg_in_size, T.i32()), frg_vector
                )
                isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
            frg_acc.store(frg_in.load() + frg_acc.load())

        # BUG FIX: was `copy_atom_stg`, which is undefined in this kernel
        # (only copy_atom_load/copy_atom_store exist) and raised a NameError.
        cute.copy(copy_atom_store, frg_acc, thr_out)
|
||||
|
||||
|
||||
def run_all_reduce(rank, M, N, dtype: Type[cutlass.Numeric]):
    """Per-process driver: build symmetric buffers, run the custom all-reduce
    kernel on this rank, and verify against dist.all_reduce.

    Spawned once per rank by mp.spawn; `dtype` is currently unused by the
    body (tensors are created with torch's default float32).
    """
    setup(rank, WORLD_SIZE)

    input_tensor = torch.randn(M * N, device=f"cuda:{rank}")
    output_tensor = torch.zeros(M * N, device=f"cuda:{rank}")

    # init tensors on different devices
    # NOTE(review): `.neg_()` appears to seed the sentinel; this only yields
    # -0.0 if the fresh allocation is zero-filled — confirm symm_mem.empty's
    # initialization guarantee.
    t = symm_mem.empty(
        [
            PING_PONG_SIZE,
            WORLD_SIZE,
            M * N,
        ],
        device="cuda",
    ).neg_()
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    # Views of every peer's buffer, permuted to (elements, rank, ping_pong).
    # (The comprehension variable shadows the `rank` parameter but does not
    # leak out of the comprehension in Python 3.)
    buffer_tensor_list = [
        hdl.get_buffer(rank, t.shape, t.dtype).permute(2, 1, 0)
        for rank in range(WORLD_SIZE)
    ]

    # enable peer access
    driver.cuInit(0)
    dev_list = [driver.cuDeviceGet(i)[1] for i in range(WORLD_SIZE)]
    ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
    # Enable peer access from every context to every other context.
    for i in range(WORLD_SIZE):
        driver.cuCtxSetCurrent(ctx_list[i])
        for j in range(WORLD_SIZE):
            if i == j:
                continue
            driver.cuCtxEnablePeerAccess(ctx_list[j], 0)
    # Restore this rank's own context as current.
    driver.cuCtxSetCurrent(ctx_list[rank])

    stream = cutlass.cuda.default_stream()
    all_reduce_kernel = AllReduceKernel()
    dlpack_buffers = [from_dlpack(x, assumed_align=32) for x in buffer_tensor_list]
    # Second argument (0) is the `signal` counter selecting the ping slice.
    all_reduce_kernel(
        rank,
        0,
        from_dlpack(input_tensor, assumed_align=32),
        from_dlpack(output_tensor, assumed_align=32),
        *dlpack_buffers,
        stream,
    )
    # NOTE(review): synchronizes device 0 on every rank — likely intended to
    # be this rank's device; confirm.
    torch.cuda.synchronize(0)

    # use torch api to get reference and inplace stored to input_tensor
    ref_tensor = input_tensor.clone()
    dist.all_reduce(ref_tensor, op=dist.ReduceOp.SUM)

    # check result of output tensor, allow small error due to different accumulator datatypes
    equal_mask = (ref_tensor.cpu() - output_tensor.cpu()).abs() < 1e-4
    result = (equal_mask.sum()).item() == ref_tensor.numel()

    if result:
        print(f"rank {rank} test passed")
    else:
        print(f"rank {rank} test failed")
        # Exact-inequality mask here differs from the tolerance mask above;
        # it is only used for diagnostic printing.
        print(
            "ref_tensor[ref_tensor != output_tensor]: ",
            ref_tensor[ref_tensor != output_tensor],
        )
        print(
            "output_tensor[ref_tensor != output_tensor]: ",
            output_tensor[ref_tensor != output_tensor],
        )

    cleanup()
|
||||
|
||||
|
||||
def main():
    """Spawn WORLD_SIZE processes (one per GPU) running run_all_reduce."""
    M = 1024
    N = 1024

    # each process will run run_all_reduce on different device
    # (requires WORLD_SIZE (=8) CUDA devices; rank is prepended by mp.spawn)
    mp.spawn(run_all_reduce, args=(M, N, cutlass.Float32), nprocs=WORLD_SIZE, join=True)

    return
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
166
examples/python/CuTeDSL/ampere/call_bypass_dlpack.py
Normal file
166
examples/python/CuTeDSL/ampere/call_bypass_dlpack.py
Normal file
@ -0,0 +1,166 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import sys
|
||||
import os
|
||||
from typing import Tuple
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import make_ptr
|
||||
|
||||
|
||||
"""
|
||||
An Example demonstrating how to call off-the-shelf kernel by-passing dlpack protocol
|
||||
|
||||
The example shows how to directly pass pointers from PyTorch tensors to off-the-shelf kernels
|
||||
written by CuTe DSL with a thin customized wrapper jit function. The jit function will be
|
||||
compiled with inline without introducing overhead.
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/call_bypass_dlpack.py
|
||||
|
||||
|
||||
It's worth mentioning that bypassing the dlpack protocol can resolve the issue that dlpack doesn't handle a shape-1
mode correctly. For example, the following code will fail, because dlpack will convert the shape-1 mode
to stride-1, which propagates alignment incorrectly.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@cute.kernel
|
||||
def fails_kernel(gX: cute.Tensor):
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
mX = gX[None, bidx, None] # We wish to retain alignment
|
||||
# assert mX.iterator.alignment == 16
|
||||
|
||||
|
||||
@cute.jit
|
||||
def fails(gX_: cute.Tensor):
|
||||
gX = gX_
|
||||
fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1))
|
||||
|
||||
|
||||
gX_torch = torch.rand((128, 1, 128), device="cuda", dtype=torch.bfloat16)
|
||||
fails(from_dlpack(gX_torch, assumed_align=16))
|
||||
|
||||
"""
|
||||
|
||||
# Add the current directory to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
from tensorop_gemm import TensorOpGemm
|
||||
|
||||
|
||||
@cute.jit
def tensor_op_gemm_wrapper(
    a_ptr: cute.Pointer,
    b_ptr: cute.Pointer,
    c_ptr: cute.Pointer,
    m: cutlass.Int32,
    n: cutlass.Int32,
    k: cutlass.Int32,
    l: cutlass.Int32,
):
    """Thin jit wrapper that builds cute.Tensors from raw pointers and runs
    the TensorOpGemm example, bypassing the dlpack protocol.

    :param a_ptr: pointer to A data, logical shape (m, k, l)
    :param b_ptr: pointer to B data, logical shape (n, k, l)
    :param c_ptr: pointer to C data, logical shape (m, n, l)
    :param m,n,k,l: GEMM problem dimensions
    """
    print(f"\n[DSL INFO] Input Parameters:")
    print(f"[DSL INFO] mnkl: {(m, n, k, l)}")

    # Assume alignment of shape to call tensorop_gemm example
    m = cute.assume(m, divby=8)
    n = cute.assume(n, divby=8)

    # Torch is row major
    a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2))
    b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2))
    c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2))
    mA = cute.make_tensor(a_ptr, layout=a_layout)
    mB = cute.make_tensor(b_ptr, layout=b_layout)
    mC = cute.make_tensor(c_ptr, layout=c_layout)

    print(f"[DSL INFO] mA: {mA}")
    print(f"[DSL INFO] mB: {mB}")
    print(f"[DSL INFO] mC: {mC}")

    # Element types are taken from the pointers; accumulate in f32.
    tensor_op_gemm = TensorOpGemm(
        a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1)
    )
    print(f"\n[DSL INFO] Created TensorOpGemm instance")
    print(f"[DSL INFO] Input dtype: {a_ptr.value_type}")
    print(f"[DSL INFO] Output dtype: {c_ptr.value_type}")
    print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}")
    print(f"[DSL INFO] Atom layout: {(2, 2, 1)}")

    # No need to compile inside jit function
    tensor_op_gemm(mA, mB, mC)
    print(f"\n[DSL INFO] Executed TensorOpGemm")
|
||||
|
||||
|
||||
def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
    """Allocate torch tensors, pass their raw pointers through the jit
    wrapper, and verify the GEMM result against torch.einsum."""
    print(f"\nRunning TensorOpGemm test with:")
    print(f"Tensor dimensions: {mnkl}")

    # (M,K,L)
    a = torch.randn(
        mnkl[3], mnkl[2], mnkl[0], dtype=torch.float16, device="cuda"
    ).permute(2, 1, 0)
    # (N,K,L)
    b = torch.randn(
        mnkl[3], mnkl[2], mnkl[1], dtype=torch.float16, device="cuda"
    ).permute(2, 1, 0)
    # (M,N,L) — permute(1, 2, 0) of an (L, M, N) tensor
    c = torch.randn(
        mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda"
    ).permute(1, 2, 0)

    print(f"Input tensor shapes:")
    print(f"a: {a.shape}, dtype: {a.dtype}")
    print(f"b: {b.shape}, dtype: {b.dtype}")
    print(f"c: {c.shape}, dtype: {c.dtype}\n")

    # Wrap the raw device pointers; this bypasses dlpack entirely.
    a_ptr = make_ptr(
        cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    b_ptr = make_ptr(
        cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    c_ptr = make_ptr(
        cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, *mnkl)
    torch.cuda.synchronize()

    ref = torch.einsum("mkl,nkl->mnl", a, b)
    # NOTE(review): atol/rtol of 1e-5 is tight for an fp16 result — confirm
    # against the TensorOpGemm accumulation behavior.
    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
    print(f"\n[DSL INFO] Results verified successfully!")
    print(f"First few elements of result: \n{c[:3, :3, :3]}")
|
||||
|
||||
|
||||
# Script entry point: run one fixed (M, N, K, L) problem size.
if __name__ == "__main__":
    run_tensor_op_gemm_wrapper((512, 256, 128, 16))
|
||||
@ -226,15 +226,15 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
|
||||
print(f"c: {c.shape}, dtype: {c.dtype}\n")
|
||||
|
||||
buffer_a = BufferWithLayout(
|
||||
make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem),
|
||||
make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
|
||||
(2, 1, 0),
|
||||
)
|
||||
buffer_b = BufferWithLayout(
|
||||
make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem),
|
||||
make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
|
||||
(2, 1, 0),
|
||||
)
|
||||
buffer_c = BufferWithLayout(
|
||||
make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem),
|
||||
make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
|
||||
(2, 1, 0),
|
||||
)
|
||||
|
||||
|
||||
189
examples/python/CuTeDSL/ampere/distributed_vector_add.py
Normal file
189
examples/python/CuTeDSL/ampere/distributed_vector_add.py
Normal file
@ -0,0 +1,189 @@
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
from typing import Type
|
||||
from cuda.bindings import driver
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
|
||||
def setup(rank, world_size):
    """Join the NCCL process group as `rank` and select its CUDA device."""
    # set environment variables for torch.distributed environment
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12995"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
|
||||
|
||||
|
||||
def cleanup():
    """Tear down the process group created by setup()."""
    dist.destroy_process_group()
|
||||
|
||||
|
||||
@cute.kernel
def vector_add_kernel(
    g0: cute.Tensor,
    g1: cute.Tensor,
    g2: cute.Tensor,
    g3: cute.Tensor,
    g4: cute.Tensor,
    g5: cute.Tensor,
    g6: cute.Tensor,
    g7: cute.Tensor,
    gOut: cute.Tensor,
    tv_layout: cute.Layout,
):
    """Element-wise sum of eight peer tensors into gOut.

    Each CTA handles one tile; each thread copies its partition of every
    input into registers and accumulates, then stores the result.  Loads and
    stores use VOLATILE/SYS copy atoms since the inputs live in peer GPUs'
    symmetric memory.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    # This CTA's tile of the output and of each of the eight inputs.
    cta_coord = (None, bidx)
    local_tile_out = gOut[cta_coord]
    local_tile_list = [
        g0[cta_coord],
        g1[cta_coord],
        g2[cta_coord],
        g3[cta_coord],
        g4[cta_coord],
        g5[cta_coord],
        g6[cta_coord],
        g7[cta_coord],
    ]

    # VOLATILE + system scope: reads cross the peer-to-peer link, so bypass
    # caching/reordering assumptions.
    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    copy_atom_store = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
    thr_copy = tiled_copy.get_slice(tidx)

    # Per-thread partitions of sources, destination, and register fragments.
    thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list]
    thr_out = thr_copy.partition_D(local_tile_out)
    frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list]
    frg_acc = cute.make_fragment_like(thr_out)
    frg_acc.fill(0.0)

    # Load each peer's fragment and accumulate into frg_acc.
    for thr, frg in zip(thr_tensor_list, frg_tensor_list):
        cute.copy(copy_atom_load, thr, frg)
        tmp = frg.load() + frg_acc.load()
        frg_acc.store(tmp)

    cute.copy(copy_atom_store, frg_acc, thr_out)
|
||||
|
||||
|
||||
@cute.jit
def vector_add(
    m0: cute.Tensor,
    m1: cute.Tensor,
    m2: cute.Tensor,
    m3: cute.Tensor,
    m4: cute.Tensor,
    m5: cute.Tensor,
    m6: cute.Tensor,
    m7: cute.Tensor,
    output: cute.Tensor,
):
    """Tile the eight input tensors and the output, then launch
    vector_add_kernel with one CTA per tile.

    All tensors are flat and must share the same element count; the tile is
    128 threads x 4 elements = 512 elements.
    """
    # define constants for future use
    num_of_elements = cute.size(m0.layout)
    # 128 threads per block and 4 elements per thread
    tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
    tile = cute.size(tv_layout.shape)

    tensors = [m0, m1, m2, m3, m4, m5, m6, m7]
    divided_tensors = [
        cute.zipped_divide(tensor, cute.make_layout(tile)) for tensor in tensors
    ]
    gOut = cute.zipped_divide(output, cute.make_layout(tile))  # ((Tile),(Rest))
    vector_add_kernel(
        divided_tensors[0],
        divided_tensors[1],
        divided_tensors[2],
        divided_tensors[3],
        divided_tensors[4],
        divided_tensors[5],
        divided_tensors[6],
        divided_tensors[7],
        gOut,
        tv_layout,
    ).launch(
        grid=[num_of_elements // tile, 1, 1],
        block=[tv_layout.shape[0], 1, 1],
    )
|
||||
|
||||
|
||||
def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]):
    """Per-process driver: each rank fills its own symmetric buffer, then
    every rank sums all peers' buffers with the CuTe kernel and verifies
    against a CPU-side sum.

    Spawned once per rank by mp.spawn; `dtype` is currently unused by the
    body (buffers use symm_mem/torch defaults).
    """
    setup(rank, world_size)

    t = symm_mem.empty(M * N, device="cuda")
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    # get tensors from other devices from the symmetric memory
    # (use `peer`, not `rank`, to avoid shadowing the parameter)
    tensor_list = [
        hdl.get_buffer(peer, t.shape, t.dtype) for peer in range(world_size)
    ]
    tensor_list[rank].random_(0, 100)

    # BUG FIX: the kernel (and the reference sum below) reads every peer's
    # buffer, but without a barrier a rank could read a peer's buffer before
    # that peer finished random_() above — a cross-process data race.
    dist.barrier()

    # enable peer access
    driver.cuInit(0)
    dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)]
    ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
    driver.cuCtxSetCurrent(ctx_list[rank])
    for i in range(world_size):
        if i == rank:
            continue
        driver.cuCtxEnablePeerAccess(ctx_list[i], 0)

    output = torch.zeros(M * N, device=f"cuda:{rank}")

    # we have to explicitly pass each tensor instead of a list of tensors
    vector_add(
        from_dlpack(tensor_list[0], assumed_align=32),
        from_dlpack(tensor_list[1], assumed_align=32),
        from_dlpack(tensor_list[2], assumed_align=32),
        from_dlpack(tensor_list[3], assumed_align=32),
        from_dlpack(tensor_list[4], assumed_align=32),
        from_dlpack(tensor_list[5], assumed_align=32),
        from_dlpack(tensor_list[6], assumed_align=32),
        from_dlpack(tensor_list[7], assumed_align=32),
        from_dlpack(output, assumed_align=32),
    )

    # Reference: element-wise sum of all peers' buffers on the CPU.
    sum_tensor = sum([tensor.cpu() for tensor in tensor_list])

    # Vectorized equality check (the original summed a bool tensor with
    # Python's sum(), which iterates element by element).
    if bool((sum_tensor == output.cpu()).all()):
        print("test passed")
    else:
        print("test failed")
        print(sum_tensor.cpu())
        print(output.cpu())

    cleanup()
|
||||
|
||||
|
||||
def main():
    """Spawn one process per visible CUDA device running run_vector_add."""

    world_size = torch.cuda.device_count()
    M = 1024
    N = 1024

    # each process will run run_vector_add on different device
    # (rank is prepended to args by mp.spawn)
    mp.spawn(
        run_vector_add,
        args=(world_size, M, N, cutlass.Float32),
        nprocs=world_size,
        join=True,
    )

    return
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
114
examples/python/CuTeDSL/ampere/dynamic_smem_size.py
Normal file
114
examples/python/CuTeDSL/ampere/dynamic_smem_size.py
Normal file
@ -0,0 +1,114 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import cutlass.cute as cute
|
||||
import cutlass
|
||||
|
||||
"""
|
||||
Example of automatic shared memory size computation for configuring kernel launch
|
||||
|
||||
This example demonstrates how to let the DSL automatically set the shared memory
size for a kernel launch, rather than explicitly configuring it at launch time,
provided that developers use `SmemAllocator` for all allocations.
|
||||
|
||||
Usage:
|
||||
python dynamic_smem_size.py # Show auto inference
|
||||
"""
|
||||
|
||||
|
||||
@cute.struct
class SharedData:
    """A struct to demonstrate shared memory allocation.

    Field sizes below are what the auto smem-size computation accounts for.
    """

    values: cute.struct.MemRange[cutlass.Float32, 64]  # 256 bytes
    counter: cutlass.Int32  # 4 bytes
    flag: cutlass.Int8  # 1 byte
|
||||
|
||||
|
||||
@cute.kernel
def kernel():
    """
    Example kernel that allocates shared memory.
    The total allocation will be automatically calculated when smem=None.
    """
    allocator = cutlass.utils.SmemAllocator()

    # Allocate various types of shared memory
    # (results are intentionally unused — the point is that the allocator
    # records each allocation so the launch can size smem automatically)
    shared_data = allocator.allocate(SharedData)
    raw_buffer = allocator.allocate(512, byte_alignment=64)
    int_array = allocator.allocate_array(element_type=cutlass.Int32, num_elems=128)
    tensor_smem = allocator.allocate_tensor(
        element_type=cutlass.Float16,
        layout=cute.make_layout((32, 16)),
        byte_alignment=16,
        swizzle=None,
    )
    return
|
||||
|
||||
|
||||
@cute.kernel
def kernel_no_smem():
    """
    Example kernel that does not allocate shared memory.
    The total allocation will be automatically calculated as 0 when smem=None.
    """
    # NOTE(review): the name `tidx` suggests a thread index, but this reads
    # block_idx() — confirm intent (the guard prints once per grid x-slice).
    tidx, _, _ = cute.arch.block_idx()
    if tidx == 0:
        cute.printf("Hello world")
    return
|
||||
|
||||
|
||||
# Script entry point: compile two launches that omit `smem=` and let the DSL
# infer the shared-memory size from the kernels' SmemAllocator usage.
if __name__ == "__main__":
    # Initialize CUDA context
    cutlass.cuda.initialize_cuda_context()

    print("Launching kernel with auto smem size. (launch config `smem=None`)")

    # Compile the example
    @cute.jit
    def launch_kernel1():
        # Launch with no explicit smem size; the kernel allocates via
        # SmemAllocator, so the size is inferred.
        k = kernel()
        k.launch(
            grid=(1, 1, 1),
            block=(1, 1, 1),
        )
        print(f"Kernel recorded internal smem usage: {k.smem_usage()}")

    @cute.jit
    def launch_kernel2():
        # Same launch shape, but the kernel makes no smem allocations,
        # so the inferred size is 0.
        k = kernel_no_smem()
        k.launch(
            grid=(1, 1, 1),
            block=(1, 1, 1),
        )
        print(f"Kernel recorded internal smem usage: {k.smem_usage()}")

    cute.compile(launch_kernel1)
    cute.compile(launch_kernel2)

    print("PASS")
|
||||
@ -327,7 +327,6 @@ class FlashAttentionForwardAmpere:
|
||||
).launch(
|
||||
grid=grid_dim,
|
||||
block=[self._num_threads, 1, 1],
|
||||
smem=SharedStorage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@ -1014,13 +1013,10 @@ class FlashAttentionForwardAmpere:
|
||||
)
|
||||
|
||||
# compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e))
|
||||
acc_S_row_exp = cute.TensorSSA(
|
||||
self._exp2f(
|
||||
acc_S_row * softmax_params.softmax_scale_log2
|
||||
- row_max_cur_row * softmax_params.softmax_scale_log2
|
||||
),
|
||||
tuple(acc_S_row.shape),
|
||||
cutlass.Float32,
|
||||
acc_S_row_exp = cute.math.exp2(
|
||||
acc_S_row * softmax_params.softmax_scale_log2
|
||||
- row_max_cur_row * softmax_params.softmax_scale_log2,
|
||||
fastmath=True,
|
||||
)
|
||||
# acc_S_row_sum => f32
|
||||
acc_S_row_sum = acc_S_row_exp.reduce(
|
||||
@ -1028,9 +1024,10 @@ class FlashAttentionForwardAmpere:
|
||||
)
|
||||
# if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum.
|
||||
if cutlass.const_expr(not is_first_n_block):
|
||||
prev_minus_cur_exp = self._exp2f(
|
||||
prev_minus_cur_exp = cute.math.exp2(
|
||||
row_max_prev_row * softmax_params.softmax_scale_log2
|
||||
- row_max_cur_row * softmax_params.softmax_scale_log2
|
||||
- row_max_cur_row * softmax_params.softmax_scale_log2,
|
||||
fastmath=True,
|
||||
)
|
||||
acc_S_row_sum = (
|
||||
acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp
|
||||
@ -1141,26 +1138,6 @@ class FlashAttentionForwardAmpere:
|
||||
"""
|
||||
return self._threadquad_reduce(val, lambda x, y: x + y)
|
||||
|
||||
def _exp2f(
|
||||
self, x: Union[cute.TensorSSA, cutlass.Float32]
|
||||
) -> Union[cute.TensorSSA, cutlass.Float32]:
|
||||
"""exp2f calculation for both vector and scalar.
|
||||
|
||||
:param x: input value
|
||||
:type x: cute.TensorSSA or cutlass.Float32
|
||||
:return: exp2 value
|
||||
:rtype: cute.TensorSSA or cutlass.Float32
|
||||
"""
|
||||
if isinstance(x, cute.TensorSSA):
|
||||
res = cute.make_fragment(x.shape, cutlass.Float32)
|
||||
res.store(x)
|
||||
|
||||
for i in range(cute.size(x.shape)):
|
||||
res[i] = self._exp2f(res[i])
|
||||
|
||||
return res.load()
|
||||
return cute.arch.exp2(x)
|
||||
|
||||
|
||||
def run(
|
||||
dtype: Type[cutlass.Numeric],
|
||||
|
||||
@ -136,10 +136,6 @@ class SGemm:
|
||||
stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
|
||||
)
|
||||
|
||||
smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
|
||||
mB.element_type, sB_layout
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create copy layouts that will be used for asynchronous
|
||||
# global memory -> shared memory copies:
|
||||
@ -258,7 +254,6 @@ class SGemm:
|
||||
).launch(
|
||||
grid=grid_dim,
|
||||
block=[cute.size(atoms_layout), 1, 1],
|
||||
smem=smem_size,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@ -738,14 +733,20 @@ def run(
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream)
|
||||
compiled_fn = cute.compile(
|
||||
sgemm,
|
||||
a_tensor,
|
||||
b_tensor,
|
||||
c_tensor,
|
||||
stream=current_stream,
|
||||
)
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
if not skip_ref_check:
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
compiled_fn(a_tensor, b_tensor, c_tensor)
|
||||
torch.cuda.synchronize()
|
||||
print("Verifying results...")
|
||||
ref = torch.einsum("mk,nk->mn", a, b)
|
||||
@ -804,7 +805,7 @@ def run(
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
compiled_fn,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
@ -837,6 +838,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--c_major", choices=["n", "m"], default="n")
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--static_shape", action="store_true")
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
|
||||
@ -69,7 +69,7 @@ class complex:
|
||||
class SharedStorage:
|
||||
# struct elements with natural alignment
|
||||
a: cute.struct.MemRange[cutlass.Float32, 32] # array
|
||||
b: cutlass.Int64 # scalar
|
||||
b: cutlass.Int64 # saclar
|
||||
c: complex # nested struct
|
||||
# struct elements with strict alignment
|
||||
x: cute.struct.Align[
|
||||
|
||||
@ -471,7 +471,7 @@ class TensorOpGemm:
|
||||
cute.arch.sync_threads()
|
||||
# Start async loads for the first k-tile. Here we take care of the k residue
|
||||
# via if/else check along the k dimension. Because we shifted the identity tensor
|
||||
# by the residue_k and because the identity tensor is a counting tensor, the
|
||||
# by the residue_k and because the identity tensor is a coord tensor, the
|
||||
# values of any identity tensor element that is poison is less than -1
|
||||
num_smem_stages = cute.size(tAsA, mode=[3])
|
||||
k_tile_count = cute.size(tAgA, mode=[3])
|
||||
@ -683,7 +683,7 @@ class TensorOpGemm:
|
||||
# Copy results of D back to shared memory
|
||||
cute.autovec_copy(tCrD, tCsC)
|
||||
|
||||
# Create counting tensor for C
|
||||
# Create coord tensor for C
|
||||
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
|
||||
mcC = cute.make_identity_tensor(
|
||||
(
|
||||
|
||||
@ -610,7 +610,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@ -797,7 +796,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||||
)
|
||||
k_block_cnt = cute.size(gA_mkl, mode=[3])
|
||||
k_tile_cnt = cute.size(gA_mkl, mode=[3])
|
||||
|
||||
#
|
||||
# Partition global tensor for TiledMMA_A/B/C
|
||||
@ -946,17 +945,17 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
||||
]
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
|
||||
ab_producer_state.reset_count()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
#
|
||||
# Tma load loop
|
||||
#
|
||||
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
|
||||
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(
|
||||
ab_producer_state, peek_ab_empty_status
|
||||
@ -992,10 +991,10 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
mcast_mask=sfb_full_mcast_mask,
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||||
ab_producer_state.advance()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
@ -1103,10 +1102,10 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = 0
|
||||
# Peek (try_wait) AB buffer full for k_tile = 0
|
||||
ab_consumer_state.reset_count()
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
|
||||
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||||
ab_consumer_state
|
||||
)
|
||||
@ -1125,7 +1124,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
#
|
||||
# Mma mainloop
|
||||
#
|
||||
for k_block in range(k_block_cnt):
|
||||
for k_tile in range(k_tile_cnt):
|
||||
if is_leader_cta:
|
||||
# Conditionally wait for AB buffer full
|
||||
ab_pipeline.consumer_wait(
|
||||
@ -1154,44 +1153,44 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
)
|
||||
|
||||
# tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (
|
||||
num_kblocks = cute.size(tCrA, mode=[2])
|
||||
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
||||
kblock_coord = (
|
||||
None,
|
||||
None,
|
||||
kphase_idx,
|
||||
kblock_idx,
|
||||
ab_consumer_state.index,
|
||||
)
|
||||
|
||||
# Set SFA/SFB tensor to tiled_mma
|
||||
sf_kphase_coord = (None, None, kphase_idx)
|
||||
sf_kblock_coord = (None, None, kblock_idx)
|
||||
tiled_mma.set(
|
||||
tcgen05.Field.SFA,
|
||||
tCtSFA[sf_kphase_coord].iterator,
|
||||
tCtSFA[sf_kblock_coord].iterator,
|
||||
)
|
||||
tiled_mma.set(
|
||||
tcgen05.Field.SFB,
|
||||
tCtSFB[sf_kphase_coord].iterator,
|
||||
tCtSFB[sf_kblock_coord].iterator,
|
||||
)
|
||||
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[kphase_coord],
|
||||
tCrB[kphase_coord],
|
||||
tCrA[kblock_coord],
|
||||
tCrB[kblock_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
|
||||
# Enable accumulate on tCtAcc after first kphase
|
||||
# Enable accumulate on tCtAcc after first kblock
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Async arrive AB buffer empty
|
||||
ab_pipeline.consumer_release(ab_consumer_state)
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = k_block + 1
|
||||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||||
ab_consumer_state.advance()
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt:
|
||||
if ab_consumer_state.count < k_tile_cnt:
|
||||
if is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||||
ab_consumer_state
|
||||
|
||||
@ -486,7 +486,6 @@ class DenseGemmKernel:
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@ -660,7 +659,7 @@ class DenseGemmKernel:
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||||
)
|
||||
k_block_cnt = cute.size(gA_mkl, mode=[3])
|
||||
k_tile_cnt = cute.size(gA_mkl, mode=[3])
|
||||
|
||||
#
|
||||
# Partition global tensor for TiledMMA_A/B/C
|
||||
@ -788,19 +787,19 @@ class DenseGemmKernel:
|
||||
#
|
||||
# Pipelining TMA load A/B and MMA mainloop
|
||||
#
|
||||
prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt)
|
||||
prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt)
|
||||
|
||||
if warp_idx == 0:
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
#
|
||||
# Prefetch TMA load A/B
|
||||
#
|
||||
for prefetch_idx in cutlass.range(prefetch_k_block_cnt, unroll=1):
|
||||
for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
|
||||
|
||||
@ -820,27 +819,27 @@ class DenseGemmKernel:
|
||||
mcast_mask=b_full_mcast_mask,
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||||
ab_producer_state.advance()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = 0
|
||||
# Peek (try_wait) AB buffer full for k_tile = 0
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
|
||||
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
|
||||
|
||||
#
|
||||
# MMA mainloop
|
||||
#
|
||||
for k_block in range(k_block_cnt):
|
||||
for k_tile in range(k_tile_cnt):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
|
||||
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
# TMA load A/B
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
@ -862,35 +861,35 @@ class DenseGemmKernel:
|
||||
ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (None, None, kphase_idx, ab_consumer_state.index)
|
||||
num_kblocks = cute.size(tCrA, mode=[2])
|
||||
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
||||
kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
|
||||
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[kphase_coord],
|
||||
tCrB[kphase_coord],
|
||||
tCrA[kblock_coord],
|
||||
tCrB[kblock_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
# Enable accumulate on tCtAcc after first kphase
|
||||
# Enable accumulate on tCtAcc after first kblock
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Async arrive AB buffer empty
|
||||
ab_pipeline.consumer_release(ab_consumer_state)
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||||
ab_producer_state.advance()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = k_block + 1
|
||||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||||
ab_consumer_state.advance()
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt:
|
||||
if ab_consumer_state.count < k_tile_cnt:
|
||||
if is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||||
ab_consumer_state
|
||||
@ -1009,8 +1008,8 @@ class DenseGemmKernel:
|
||||
# Wait A/B buffer empty
|
||||
#
|
||||
if warp_idx == 0:
|
||||
# Reverse prefetch_k_block_cnt times to next available buffer
|
||||
for i in range(prefetch_k_block_cnt):
|
||||
# Reverse prefetch_k_tile_cnt times to next available buffer
|
||||
for i in range(prefetch_k_tile_cnt):
|
||||
ab_producer_state.reverse()
|
||||
ab_pipeline.producer_tail(ab_producer_state)
|
||||
return
|
||||
|
||||
@ -510,7 +510,6 @@ class PersistentDenseGemmKernel:
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@ -669,7 +668,7 @@ class PersistentDenseGemmKernel:
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||||
)
|
||||
k_block_cnt = cute.size(gA_mkl, mode=[3])
|
||||
k_tile_cnt = cute.size(gA_mkl, mode=[3])
|
||||
|
||||
#
|
||||
# Partition global tensor for TiledMMA_A/B/C
|
||||
@ -774,17 +773,17 @@ class PersistentDenseGemmKernel:
|
||||
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
||||
]
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
|
||||
ab_producer_state.reset_count()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
#
|
||||
# Tma load loop
|
||||
#
|
||||
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
|
||||
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(
|
||||
ab_producer_state, peek_ab_empty_status
|
||||
@ -806,10 +805,10 @@ class PersistentDenseGemmKernel:
|
||||
mcast_mask=b_full_mcast_mask,
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||||
ab_producer_state.advance()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
@ -877,10 +876,10 @@ class PersistentDenseGemmKernel:
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = 0
|
||||
# Peek (try_wait) AB buffer full for k_tile = 0
|
||||
ab_consumer_state.reset_count()
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
|
||||
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||||
ab_consumer_state
|
||||
)
|
||||
@ -899,7 +898,7 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
# Mma mainloop
|
||||
#
|
||||
for k_block in range(k_block_cnt):
|
||||
for k_tile in range(k_tile_cnt):
|
||||
if is_leader_cta:
|
||||
# Conditionally wait for AB buffer full
|
||||
ab_pipeline.consumer_wait(
|
||||
@ -907,32 +906,32 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (
|
||||
num_kblocks = cute.size(tCrA, mode=[2])
|
||||
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
||||
kblock_coord = (
|
||||
None,
|
||||
None,
|
||||
kphase_idx,
|
||||
kblock_idx,
|
||||
ab_consumer_state.index,
|
||||
)
|
||||
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[kphase_coord],
|
||||
tCrB[kphase_coord],
|
||||
tCrA[kblock_coord],
|
||||
tCrB[kblock_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
# Enable accumulate on tCtAcc after first kphase
|
||||
# Enable accumulate on tCtAcc after first kblock
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Async arrive AB buffer empty
|
||||
ab_pipeline.consumer_release(ab_consumer_state)
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = k_block + 1
|
||||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||||
ab_consumer_state.advance()
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt:
|
||||
if ab_consumer_state.count < k_tile_cnt:
|
||||
if is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||||
ab_consumer_state
|
||||
|
||||
@ -110,17 +110,6 @@ Constraints:
|
||||
"""
|
||||
|
||||
|
||||
class PipelineStateMinimal:
|
||||
"""
|
||||
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
|
||||
"""
|
||||
|
||||
def __init__(self, count, index, phase):
|
||||
self.count = count
|
||||
self.index = index
|
||||
self.phase = phase
|
||||
|
||||
|
||||
class DenseGemmKernel:
|
||||
"""
|
||||
This class implements batched matrix multiplication (C = A x B) with support for various data types
|
||||
@ -497,7 +486,6 @@ class DenseGemmKernel:
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@ -576,13 +564,19 @@ class DenseGemmKernel:
|
||||
pipeline.Agent.Thread, num_tma_producer
|
||||
)
|
||||
ab_pipeline = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
|
||||
num_stages=self.num_ab_stage,
|
||||
producer_group=ab_pipeline_producer_group,
|
||||
consumer_group=ab_pipeline_consumer_group,
|
||||
tx_count=self.num_tma_load_bytes,
|
||||
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
)
|
||||
ab_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_ab_stage
|
||||
)
|
||||
ab_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||||
)
|
||||
|
||||
# Initialize acc_pipeline (barrier) and states
|
||||
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
@ -590,10 +584,10 @@ class DenseGemmKernel:
|
||||
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
|
||||
)
|
||||
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
|
||||
num_stages=self.num_acc_stage,
|
||||
producer_group=acc_pipeline_producer_group,
|
||||
consumer_group=acc_pipeline_consumer_group,
|
||||
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
)
|
||||
acc_producer_state = pipeline.make_pipeline_state(
|
||||
@ -665,7 +659,7 @@ class DenseGemmKernel:
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||||
)
|
||||
k_block_cnt = cute.size(gA_mkl, mode=[3])
|
||||
k_tile_cnt = cute.size(gA_mkl, mode=[3])
|
||||
|
||||
#
|
||||
# Partition global tensor for TiledMMA_A/B/C
|
||||
@ -793,24 +787,12 @@ class DenseGemmKernel:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# MAINLOOP
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt)
|
||||
prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt)
|
||||
if warp_idx == 0:
|
||||
for k_block in cutlass.range(
|
||||
k_block_cnt,
|
||||
pipelining=self.num_ab_stage - 2,
|
||||
for k_tile in cutlass.range(
|
||||
k_tile_cnt,
|
||||
prefetch_stages=self.num_ab_stage - 2,
|
||||
):
|
||||
ab_producer_state = PipelineStateMinimal(
|
||||
k_block,
|
||||
k_block % self.num_ab_stage,
|
||||
cutlass.Int32((k_block // self.num_ab_stage) % 2) ^ 1,
|
||||
)
|
||||
|
||||
ab_consumer_state = PipelineStateMinimal(
|
||||
k_block,
|
||||
k_block % self.num_ab_stage,
|
||||
cutlass.Int32((k_block // self.num_ab_stage) % 2),
|
||||
)
|
||||
|
||||
# wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(ab_producer_state)
|
||||
|
||||
@ -835,22 +817,26 @@ class DenseGemmKernel:
|
||||
ab_pipeline.consumer_wait(ab_consumer_state)
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (None, None, kphase_idx, ab_consumer_state.index)
|
||||
num_kblocks = cute.size(tCrA, mode=[2])
|
||||
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
||||
kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
|
||||
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[kphase_coord],
|
||||
tCrB[kphase_coord],
|
||||
tCrA[kblock_coord],
|
||||
tCrB[kblock_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
# Enable accumulate on tCtAcc after first kphase
|
||||
# Enable accumulate on tCtAcc after first kblock
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Async arrive AB buffer empty
|
||||
ab_pipeline.consumer_release(ab_consumer_state)
|
||||
|
||||
ab_producer_state.advance()
|
||||
ab_consumer_state.advance()
|
||||
|
||||
# Async arrive accumulator buffer full
|
||||
if is_leader_cta:
|
||||
acc_pipeline.producer_commit(acc_producer_state)
|
||||
@ -964,12 +950,10 @@ class DenseGemmKernel:
|
||||
# Wait A/B buffer empty
|
||||
#
|
||||
if warp_idx == 0:
|
||||
ab_producer_state = PipelineStateMinimal(
|
||||
k_block_cnt,
|
||||
k_block_cnt % self.num_ab_stage,
|
||||
cutlass.Int32((k_block_cnt // self.num_ab_stage) % 2) ^ 1,
|
||||
)
|
||||
ab_pipeline.producer_acquire(ab_producer_state)
|
||||
# Reverse prefetch_k_tile_cnt times to next available buffer
|
||||
for i in range(prefetch_k_tile_cnt):
|
||||
ab_producer_state.reverse()
|
||||
ab_pipeline.producer_tail(ab_producer_state)
|
||||
return
|
||||
|
||||
def epilog_tmem_copy_and_partition(
|
||||
@ -1579,7 +1563,6 @@ def run_dense_gemm(
|
||||
warmup_iterations: int = 0,
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
measure_launch_overhead=False,
|
||||
):
|
||||
"""
|
||||
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
||||
@ -1725,7 +1708,7 @@ def run_dense_gemm(
|
||||
ref_c = ref
|
||||
elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
|
||||
# m major: (l, n, m) -> (m, n, l)
|
||||
# k major: (l, m, n) -> (m, n, l)
|
||||
# n major: (l, m, n) -> (m, n, l)
|
||||
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
|
||||
shape = (l, m, n) if c_major == "n" else (l, n, m)
|
||||
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -475,7 +475,6 @@ class GroupedGemmKernel:
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@ -785,7 +784,7 @@ class GroupedGemmKernel:
|
||||
)
|
||||
tensormap_init_done = cutlass.Boolean(False)
|
||||
# tile count we have searched
|
||||
total_k_block_cnt = cutlass.Int32(0)
|
||||
total_k_tile_cnt = cutlass.Int32(0)
|
||||
# group index of last tile
|
||||
last_group_idx = cutlass.Int32(-1)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
@ -795,7 +794,7 @@ class GroupedGemmKernel:
|
||||
cur_tile_coord,
|
||||
problem_sizes_mnkl,
|
||||
)
|
||||
cur_k_block_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
|
||||
cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
|
||||
cur_group_idx = grouped_gemm_cta_tile_info.group_idx
|
||||
is_group_changed = cur_group_idx != last_group_idx
|
||||
# skip tensormap update if we're working on the same group
|
||||
@ -861,17 +860,17 @@ class GroupedGemmKernel:
|
||||
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
||||
]
|
||||
|
||||
num_prev_k_blk = total_k_block_cnt
|
||||
total_k_block_cnt += cur_k_block_cnt
|
||||
num_prev_k_blk = total_k_tile_cnt
|
||||
total_k_tile_cnt += cur_k_tile_cnt
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
|
||||
tma_wr_k_block = cutlass.Int32(0)
|
||||
smem_wr_buffer = (num_prev_k_blk + tma_wr_k_block) % self.num_ab_stage
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
|
||||
tma_wr_k_tile = cutlass.Int32(0)
|
||||
smem_wr_buffer = (num_prev_k_blk + tma_wr_k_tile) % self.num_ab_stage
|
||||
tma_wr_ab_empty_phase = (
|
||||
num_prev_k_blk + tma_wr_k_block
|
||||
num_prev_k_blk + tma_wr_k_tile
|
||||
) // self.num_ab_stage % 2 ^ 1
|
||||
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
|
||||
tma_wr_k_block < cur_k_block_cnt,
|
||||
tma_wr_k_tile < cur_k_tile_cnt,
|
||||
ab_empty_mbar_ptr + smem_wr_buffer,
|
||||
tma_wr_ab_empty_phase,
|
||||
)
|
||||
@ -882,10 +881,10 @@ class GroupedGemmKernel:
|
||||
#
|
||||
# Tma load loop
|
||||
#
|
||||
for k_block in cutlass.range(0, cur_k_block_cnt, 1, unroll=1):
|
||||
tma_wr_k_block_next = tma_wr_k_block + 1
|
||||
for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1):
|
||||
tma_wr_k_tile_next = tma_wr_k_tile + 1
|
||||
smem_wr_buffer_next = (
|
||||
num_prev_k_blk + tma_wr_k_block_next
|
||||
num_prev_k_blk + tma_wr_k_tile_next
|
||||
) % self.num_ab_stage
|
||||
tma_wr_ab_empty_phase_next = (
|
||||
tma_wr_ab_empty_phase ^ 1
|
||||
@ -911,7 +910,7 @@ class GroupedGemmKernel:
|
||||
# Load A/B with TMA
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
tAgA_slice[(None, tma_wr_k_block)],
|
||||
tAgA_slice[(None, tma_wr_k_tile)],
|
||||
tAsA[(None, smem_wr_buffer)],
|
||||
tma_bar_ptr=smem_full_mbar_ptr,
|
||||
mcast_mask=a_full_mcast_mask,
|
||||
@ -922,7 +921,7 @@ class GroupedGemmKernel:
|
||||
)
|
||||
cute.copy(
|
||||
tma_atom_b,
|
||||
tBgB_slice[(None, tma_wr_k_block)],
|
||||
tBgB_slice[(None, tma_wr_k_tile)],
|
||||
tBsB[(None, smem_wr_buffer)],
|
||||
tma_bar_ptr=smem_full_mbar_ptr,
|
||||
mcast_mask=b_full_mcast_mask,
|
||||
@ -932,14 +931,14 @@ class GroupedGemmKernel:
|
||||
),
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||||
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
|
||||
tma_wr_k_block_next < cur_k_block_cnt,
|
||||
tma_wr_k_tile_next < cur_k_tile_cnt,
|
||||
ab_empty_mbar_ptr + smem_wr_buffer_next,
|
||||
tma_wr_ab_empty_phase_next,
|
||||
)
|
||||
|
||||
tma_wr_k_block = tma_wr_k_block_next
|
||||
tma_wr_k_tile = tma_wr_k_tile_next
|
||||
smem_wr_buffer = smem_wr_buffer_next
|
||||
tma_wr_ab_empty_phase = tma_wr_ab_empty_phase_next
|
||||
|
||||
@ -998,12 +997,12 @@ class GroupedGemmKernel:
|
||||
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
# tile count we have searched
|
||||
total_k_block_cnt = cutlass.Int32(0)
|
||||
total_k_tile_cnt = cutlass.Int32(0)
|
||||
while work_tile.is_valid_tile:
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
# MMA warp is only interested in number of tiles along K dimension
|
||||
(
|
||||
cur_k_block_cnt,
|
||||
cur_k_tile_cnt,
|
||||
cur_group_idx,
|
||||
) = group_gemm_ts_helper.search_cluster_tile_count_k(
|
||||
cur_tile_coord,
|
||||
@ -1014,17 +1013,17 @@ class GroupedGemmKernel:
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_buf_idx)]
|
||||
|
||||
num_prev_k_blk = total_k_block_cnt
|
||||
total_k_block_cnt += cur_k_block_cnt
|
||||
num_prev_k_blk = total_k_tile_cnt
|
||||
total_k_tile_cnt += cur_k_tile_cnt
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = 0
|
||||
mma_rd_k_block = cutlass.Int32(0)
|
||||
smem_rd_buffer = (num_prev_k_blk + mma_rd_k_block) % self.num_ab_stage
|
||||
# Peek (try_wait) AB buffer full for k_tile = 0
|
||||
mma_rd_k_tile = cutlass.Int32(0)
|
||||
smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage
|
||||
need_check_rd_buffer_full = (
|
||||
mma_rd_k_block < cur_k_block_cnt and is_leader_cta
|
||||
mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta
|
||||
)
|
||||
mma_rd_ab_full_phase = (
|
||||
(num_prev_k_blk + mma_rd_k_block) // self.num_ab_stage % 2
|
||||
(num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2
|
||||
)
|
||||
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
|
||||
need_check_rd_buffer_full,
|
||||
@ -1051,10 +1050,10 @@ class GroupedGemmKernel:
|
||||
#
|
||||
# Mma mainloop
|
||||
#
|
||||
for k_block in range(cur_k_block_cnt):
|
||||
mma_rd_k_block_next = cutlass.Int32(k_block + 1)
|
||||
for k_tile in range(cur_k_tile_cnt):
|
||||
mma_rd_k_tile_next = cutlass.Int32(k_tile + 1)
|
||||
smem_rd_buffer_next = (
|
||||
num_prev_k_blk + mma_rd_k_block_next
|
||||
num_prev_k_blk + mma_rd_k_tile_next
|
||||
) % self.num_ab_stage
|
||||
mma_rd_ab_full_phase_next = (
|
||||
mma_rd_ab_full_phase ^ 1
|
||||
@ -1069,18 +1068,18 @@ class GroupedGemmKernel:
|
||||
)
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (None, None, kphase_idx, smem_rd_buffer)
|
||||
num_kblocks = cute.size(tCrA, mode=[2])
|
||||
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
||||
kblock_coord = (None, None, kblock_idx, smem_rd_buffer)
|
||||
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[kphase_coord],
|
||||
tCrB[kphase_coord],
|
||||
tCrA[kblock_coord],
|
||||
tCrB[kblock_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
# Enable accumulate on tCtAcc after first kphase
|
||||
# Enable accumulate on tCtAcc after first kblock
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Async arrive AB buffer empty
|
||||
@ -1091,9 +1090,9 @@ class GroupedGemmKernel:
|
||||
self.cta_group,
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = k_block + 1
|
||||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||||
need_check_rd_buffer_full = (
|
||||
mma_rd_k_block_next < cur_k_block_cnt and is_leader_cta
|
||||
mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta
|
||||
)
|
||||
|
||||
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
|
||||
@ -1102,7 +1101,7 @@ class GroupedGemmKernel:
|
||||
mma_rd_ab_full_phase_next,
|
||||
)
|
||||
|
||||
mma_rd_k_block = mma_rd_k_block_next
|
||||
mma_rd_k_tile = mma_rd_k_tile_next
|
||||
smem_rd_buffer = smem_rd_buffer_next
|
||||
mma_rd_ab_full_phase = mma_rd_ab_full_phase_next
|
||||
|
||||
@ -1201,7 +1200,7 @@ class GroupedGemmKernel:
|
||||
# wait tensormap initialization complete before update
|
||||
tensormap_manager.fence_tensormap_initialization()
|
||||
# tile count we have searched
|
||||
total_k_block_cnt = cutlass.Int32(0)
|
||||
total_k_tile_cnt = cutlass.Int32(0)
|
||||
# group index of last tile
|
||||
last_group_idx = cutlass.Int32(-1)
|
||||
while work_tile.is_valid_tile:
|
||||
@ -1240,8 +1239,8 @@ class GroupedGemmKernel:
|
||||
grouped_gemm_cta_tile_info.cta_tile_idx_n,
|
||||
0,
|
||||
)
|
||||
cur_k_block_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
|
||||
total_k_block_cnt += cur_k_block_cnt
|
||||
cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
|
||||
total_k_tile_cnt += cur_k_tile_cnt
|
||||
|
||||
#
|
||||
# Slice to per mma tile index
|
||||
@ -1370,8 +1369,8 @@ class GroupedGemmKernel:
|
||||
#
|
||||
if warp_idx == self.epilog_warp_id[0]:
|
||||
cute.arch.mbarrier_wait(
|
||||
(ab_empty_mbar_ptr + ((total_k_block_cnt - 1) % self.num_ab_stage)),
|
||||
(((total_k_block_cnt - 1) // self.num_ab_stage) % 2),
|
||||
(ab_empty_mbar_ptr + ((total_k_tile_cnt - 1) % self.num_ab_stage)),
|
||||
(((total_k_tile_cnt - 1) // self.num_ab_stage) % 2),
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
|
||||
@ -622,7 +622,6 @@ class SSDKernel:
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=self.cluster_shape_mnk,
|
||||
min_blocks_per_mp=1,
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@ -693,7 +692,7 @@ class SSDKernel:
|
||||
G = cute.size(tma_tensor_b, mode=[3])
|
||||
NGROUP_RATIO = EH // G
|
||||
|
||||
# Make tiledMma
|
||||
# Make TiledMma
|
||||
(
|
||||
tiled_mma_intra1,
|
||||
tiled_mma_intra2,
|
||||
@ -1745,7 +1744,7 @@ class SSDKernel:
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Combine INTER1_ACC/last_column/State
|
||||
exp_last_column = cute.arch.exp(last_column.ir_value())
|
||||
exp_last_column = cute.math.exp(last_column, fastmath=True)
|
||||
for reg_idx in range(0, cute.size(tTR_rP), 2):
|
||||
(
|
||||
tTR_rP[reg_idx],
|
||||
@ -2267,9 +2266,11 @@ class SSDKernel:
|
||||
) = cute.arch.fma_packed_f32x2(
|
||||
(tTR_rInter[reg_idx], tTR_rInter[reg_idx + 1]),
|
||||
(
|
||||
cute.arch.exp(tTR_rDeltaA[reg_idx].ir_value()),
|
||||
cute.arch.exp(
|
||||
tTR_rDeltaA[reg_idx + 1].ir_value()
|
||||
cute.math.exp(
|
||||
tTR_rDeltaA[reg_idx], fastmath=True
|
||||
),
|
||||
cute.math.exp(
|
||||
tTR_rDeltaA[reg_idx + 1], fastmath=True
|
||||
),
|
||||
),
|
||||
(tTR_rIntra[reg_idx], tTR_rIntra[reg_idx + 1]),
|
||||
@ -3072,14 +3073,19 @@ class SSDKernel:
|
||||
m, n = tCoord[subtile_idx]
|
||||
if m < n:
|
||||
tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
|
||||
LOG2_E = cutlass.Float32(1.4426950408889634)
|
||||
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
|
||||
# TODO: use math.exp directly
|
||||
tCompute_log2e = cute.arch.mul_packed_f32x2(
|
||||
(tCompute[subtile_idx], tCompute[subtile_idx + 1]), (LOG2_E, LOG2_E)
|
||||
)
|
||||
(
|
||||
tCompute[subtile_idx],
|
||||
tCompute[subtile_idx + 1],
|
||||
) = cute.arch.mul_packed_f32x2(
|
||||
cute.arch.exp_packed_f32x2(
|
||||
(tCompute[subtile_idx], tCompute[subtile_idx + 1])
|
||||
(
|
||||
cute.math.exp2(tCompute_log2e[0], fastmath=True),
|
||||
cute.math.exp2(tCompute_log2e[1], fastmath=True),
|
||||
),
|
||||
(tCrDelta[subtile_idx], tCrDelta[subtile_idx + 1]),
|
||||
)
|
||||
@ -3245,11 +3251,11 @@ class SSDKernel:
|
||||
for reg_idx in range(0, cute.size(tBrB_Compute), 2):
|
||||
tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2(
|
||||
(
|
||||
cute.arch.exp(
|
||||
(last_column - tBrDeltaA_Compute[reg_idx]).ir_value()
|
||||
cute.math.exp(
|
||||
(last_column - tBrDeltaA_Compute[reg_idx]), fastmath=True
|
||||
),
|
||||
cute.arch.exp(
|
||||
(last_column - tBrDeltaA_Compute[reg_idx + 1]).ir_value()
|
||||
cute.math.exp(
|
||||
(last_column - tBrDeltaA_Compute[reg_idx + 1]), fastmath=True
|
||||
),
|
||||
),
|
||||
(tBrDelta_Compute[reg_idx], tBrDelta_Compute[reg_idx + 1]),
|
||||
|
||||
@ -44,7 +44,7 @@ import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
"""
|
||||
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
|
||||
using CUTE DSL.
|
||||
using CuTe DSL.
|
||||
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
||||
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
||||
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
||||
@ -70,7 +70,7 @@ To run this example:
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/hopper/dense_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
|
||||
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major k --b_major k --c_major n
|
||||
@ -85,7 +85,7 @@ To collect performance with NCU profiler:
|
||||
.. code-block:: bash
|
||||
|
||||
ncu python examples/hopper/dense_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
|
||||
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major k --b_major k --c_major n
|
||||
@ -95,14 +95,11 @@ Constraints:
|
||||
* For fp16 types, A and B must have the same data type
|
||||
* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
|
||||
* Fp8 types only support k-major layout
|
||||
* Only fp32 accumulation is supported in this example
|
||||
* CTA tile shape M must be 64/128
|
||||
* CTA tile shape N must be 64/128/256
|
||||
* CTA tile shape K must be 64
|
||||
* Cluster shape M/N must be positive and power of 2, total cluster size <= 4
|
||||
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
||||
i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively.
|
||||
* OOB tiles are not allowed when TMA store is disabled
|
||||
"""
|
||||
|
||||
|
||||
@ -128,10 +125,10 @@ def parse_arguments() -> argparse.Namespace:
|
||||
help="mnkl dimensions (comma-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_shape_mnk",
|
||||
"--tile_shape_mn",
|
||||
type=parse_comma_separated_ints,
|
||||
choices=[(128, 128, 64), (128, 256, 64), (128, 64, 64), (64, 64, 64)],
|
||||
default=(128, 128, 64),
|
||||
choices=[(128, 128), (128, 256), (128, 64), (64, 64)],
|
||||
default=(128, 128),
|
||||
help="Cta tile shape (comma-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -190,8 +187,8 @@ def parse_arguments() -> argparse.Namespace:
|
||||
|
||||
if len(args.mnkl) != 4:
|
||||
parser.error("--mnkl must contain exactly 4 values")
|
||||
if len(args.tile_shape_mnk) != 3:
|
||||
parser.error("--tile_shape_mnk must contain exactly 3 values")
|
||||
if len(args.tile_shape_mn) != 2:
|
||||
parser.error("--tile_shape_mn must contain exactly 2 values")
|
||||
if len(args.cluster_shape_mn) != 2:
|
||||
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
||||
|
||||
@ -210,10 +207,10 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
:param acc_dtype: Data type for accumulation during computation
|
||||
:type acc_dtype: type[cutlass.Numeric]
|
||||
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
|
||||
:type tile_shape_mnk: Tuple[int, int, int]
|
||||
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
|
||||
:type cluster_shape_mnk: Tuple[int, int, int]
|
||||
:param tile_shape_mn: Shape of the CTA tile (M,N)
|
||||
:type tile_shape_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
|
||||
:note: Data type requirements:
|
||||
- For 16-bit types: A and B must have the same data type
|
||||
@ -236,8 +233,8 @@ class HopperWgmmaGemmKernel:
|
||||
Example:
|
||||
>>> gemm = HopperWgmmaGemmKernel(
|
||||
... acc_dtype=cutlass.Float32,
|
||||
... tile_shape_mnk=(128, 256, 64),
|
||||
... cluster_shape_mnk=(1, 1, 1)
|
||||
... tile_shape_mn=(128, 256),
|
||||
... cluster_shape_mn=(1, 1)
|
||||
... )
|
||||
>>> gemm(a_tensor, b_tensor, c_tensor, stream)
|
||||
"""
|
||||
@ -245,8 +242,8 @@ class HopperWgmmaGemmKernel:
|
||||
def __init__(
|
||||
self,
|
||||
acc_dtype: type[cutlass.Numeric],
|
||||
tile_shape_mnk: tuple[int, int, int],
|
||||
cluster_shape_mnk: tuple[int, int, int],
|
||||
tile_shape_mn: tuple[int, int],
|
||||
cluster_shape_mn: tuple[int, int],
|
||||
):
|
||||
"""
|
||||
Initializes the configuration for a Hopper dense GEMM kernel.
|
||||
@ -256,28 +253,30 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
:param acc_dtype: Data type for accumulation during computation
|
||||
:type acc_dtype: type[cutlass.Numeric]
|
||||
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
|
||||
:type tile_shape_mnk: Tuple[int, int, int]
|
||||
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
|
||||
:type cluster_shape_mnk: Tuple[int, int, int]
|
||||
:param tile_shape_mn: Shape of the CTA tile (M,N)
|
||||
:type tile_shape_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
"""
|
||||
|
||||
self.acc_dtype = acc_dtype
|
||||
|
||||
self.cluster_shape_mnk = cluster_shape_mnk
|
||||
self.cluster_shape_mn = cluster_shape_mn
|
||||
self.mma_inst_shape_mn = None
|
||||
self.tile_shape_mnk = tuple(tile_shape_mnk)
|
||||
# K dimension is deferred in _setup_attributes
|
||||
self.tile_shape_mnk = (*tile_shape_mn, 1)
|
||||
# For large tile size, using two warp groups is preferred because using only one warp
|
||||
# group may result in register spill
|
||||
self.atom_layout_mnk = (
|
||||
(2, 1, 1)
|
||||
if tile_shape_mnk[0] > 64 and tile_shape_mnk[1] > 128
|
||||
if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128
|
||||
else (1, 1, 1)
|
||||
)
|
||||
self.num_mcast_ctas_a = None
|
||||
self.num_mcast_ctas_b = None
|
||||
self.is_a_mcast = False
|
||||
self.is_b_mcast = False
|
||||
self.tiled_mma = None
|
||||
|
||||
self.occupancy = 1
|
||||
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
|
||||
@ -315,12 +314,27 @@ class HopperWgmmaGemmKernel:
|
||||
raise ValueError("CTA tile shape M must be 64/128")
|
||||
if self.tile_shape_mnk[1] not in [64, 128, 256]:
|
||||
raise ValueError("CTA tile shape N must be 64/128/256")
|
||||
if self.tile_shape_mnk[2] not in [64]:
|
||||
raise ValueError("CTA tile shape K must be 64")
|
||||
|
||||
self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
|
||||
self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
|
||||
self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
|
||||
self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
||||
self.a_dtype,
|
||||
self.b_dtype,
|
||||
self.a_layout.sm90_mma_major_mode(),
|
||||
self.b_layout.sm90_mma_major_mode(),
|
||||
self.acc_dtype,
|
||||
self.atom_layout_mnk,
|
||||
tiler_mn=(64, self.tile_shape_mnk[1]),
|
||||
)
|
||||
mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
|
||||
mma_inst_tile_k = 4
|
||||
self.tile_shape_mnk = (
|
||||
self.tile_shape_mnk[0],
|
||||
self.tile_shape_mnk[1],
|
||||
mma_inst_shape_k * mma_inst_tile_k,
|
||||
)
|
||||
|
||||
self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1))
|
||||
self.num_mcast_ctas_a = self.cluster_shape_mn[1]
|
||||
self.num_mcast_ctas_b = self.cluster_shape_mn[0]
|
||||
self.is_a_mcast = self.num_mcast_ctas_a > 1
|
||||
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
||||
|
||||
@ -401,28 +415,18 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
self._setup_attributes()
|
||||
|
||||
tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
||||
self.a_dtype,
|
||||
self.b_dtype,
|
||||
self.a_layout.sm90_mma_major_mode(),
|
||||
self.b_layout.sm90_mma_major_mode(),
|
||||
self.acc_dtype,
|
||||
self.atom_layout_mnk,
|
||||
tiler_mn=(64, self.tile_shape_mnk[1]),
|
||||
)
|
||||
|
||||
tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
|
||||
a,
|
||||
self.a_smem_layout_staged,
|
||||
(self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
|
||||
self.cluster_shape_mnk[1],
|
||||
self.cluster_shape_mn[1],
|
||||
)
|
||||
|
||||
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
|
||||
b,
|
||||
self.b_smem_layout_staged,
|
||||
(self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
|
||||
self.cluster_shape_mnk[0],
|
||||
self.cluster_shape_mn[0],
|
||||
)
|
||||
|
||||
tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
|
||||
@ -431,20 +435,20 @@ class HopperWgmmaGemmKernel:
|
||||
self.epi_tile,
|
||||
)
|
||||
|
||||
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mnk)
|
||||
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mn)
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
mainloop_pipeline_array_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.ab_stage * 2
|
||||
]
|
||||
sa: cute.struct.Align[
|
||||
sA: cute.struct.Align[
|
||||
cute.struct.MemRange[
|
||||
self.a_dtype, cute.cosize(self.a_smem_layout_staged)
|
||||
],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
sb: cute.struct.Align[
|
||||
sB: cute.struct.Align[
|
||||
cute.struct.MemRange[
|
||||
self.b_dtype, cute.cosize(self.b_smem_layout_staged)
|
||||
],
|
||||
@ -461,7 +465,7 @@ class HopperWgmmaGemmKernel:
|
||||
tma_tensor_b,
|
||||
tma_atom_c,
|
||||
tma_tensor_c,
|
||||
tiled_mma,
|
||||
self.tiled_mma,
|
||||
self.cta_layout_mnk,
|
||||
self.a_smem_layout_staged,
|
||||
self.b_smem_layout_staged,
|
||||
@ -469,8 +473,7 @@ class HopperWgmmaGemmKernel:
|
||||
).launch(
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=self.cluster_shape_mnk,
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@ -562,8 +565,8 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
# Get the pid from cluster id
|
||||
bidx_in_cluster = cute.arch.block_in_cluster_idx()
|
||||
pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
|
||||
pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
|
||||
pid_m = cid_m * self.cluster_shape_mn[0] + bidx_in_cluster[0]
|
||||
pid_n = cid_n * self.cluster_shape_mn[1] + bidx_in_cluster[1]
|
||||
|
||||
tile_coord_mnkl = (pid_m, pid_n, None, bidz)
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
@ -621,22 +624,22 @@ class HopperWgmmaGemmKernel:
|
||||
)
|
||||
|
||||
# Cluster arrive after barrier init
|
||||
if cute.size(self.cluster_shape_mnk) > 1:
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
cute.arch.cluster_arrive_relaxed()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Generate smem tensor A/B
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
sa = storage.sa.get_tensor(
|
||||
sA = storage.sA.get_tensor(
|
||||
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
|
||||
)
|
||||
sb = storage.sb.get_tensor(
|
||||
sB = storage.sB.get_tensor(
|
||||
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
|
||||
)
|
||||
sc_ptr = cute.recast_ptr(
|
||||
sa.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
|
||||
sC_ptr = cute.recast_ptr(
|
||||
sA.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
|
||||
)
|
||||
sc = cute.make_tensor(sc_ptr, epi_smem_layout_staged.outer)
|
||||
sC = cute.make_tensor(sC_ptr, epi_smem_layout_staged.outer)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Local_tile partition global tensors
|
||||
@ -673,34 +676,34 @@ class HopperWgmmaGemmKernel:
|
||||
# TMA load A partition_S/D
|
||||
a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
|
||||
a_cta_crd = cluster_coord_mnk[1]
|
||||
sa_for_tma_partition = cute.group_modes(sa, 0, 2)
|
||||
sA_for_tma_partition = cute.group_modes(sA, 0, 2)
|
||||
gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2)
|
||||
tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
a_cta_crd,
|
||||
a_cta_layout,
|
||||
sa_for_tma_partition,
|
||||
sA_for_tma_partition,
|
||||
gA_for_tma_partition,
|
||||
)
|
||||
|
||||
# TMA load B partition_S/D
|
||||
b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
|
||||
b_cta_crd = cluster_coord_mnk[0]
|
||||
sb_for_tma_partition = cute.group_modes(sb, 0, 2)
|
||||
sB_for_tma_partition = cute.group_modes(sB, 0, 2)
|
||||
gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2)
|
||||
tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
b_cta_crd,
|
||||
b_cta_layout,
|
||||
sb_for_tma_partition,
|
||||
sB_for_tma_partition,
|
||||
gB_for_tma_partition,
|
||||
)
|
||||
|
||||
# //////////////////////////////////////////////////////////////////////////////
|
||||
# Make frangments
|
||||
# Make fragments
|
||||
# //////////////////////////////////////////////////////////////////////////////
|
||||
tCsA = thr_mma.partition_A(sa)
|
||||
tCsB = thr_mma.partition_B(sb)
|
||||
tCsA = thr_mma.partition_A(sA)
|
||||
tCsB = thr_mma.partition_B(sB)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA)
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB)
|
||||
|
||||
@ -711,7 +714,7 @@ class HopperWgmmaGemmKernel:
|
||||
# Cluster wait
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# cluster wait for barrier init
|
||||
if cute.size(self.cluster_shape_mnk) > 1:
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
cute.arch.cluster_wait()
|
||||
else:
|
||||
cute.arch.sync_threads()
|
||||
@ -788,7 +791,7 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_tile in range(k_pipe_mmas):
|
||||
for k_tile in cutlass.range_constexpr(k_pipe_mmas):
|
||||
# Wait for A/B buffer to be ready
|
||||
mainloop_pipeline.consumer_wait(
|
||||
mainloop_consumer_read_state, peek_ab_full_status
|
||||
@ -917,7 +920,7 @@ class HopperWgmmaGemmKernel:
|
||||
# /////////////////////////////////////////////////////////////////////////////
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
if cute.size(self.cluster_shape_mnk) > 1:
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
# Wait for all threads in the cluster to finish, avoid early release of smem
|
||||
cute.arch.cluster_arrive()
|
||||
cute.arch.cluster_wait()
|
||||
@ -950,33 +953,45 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sD = thr_copy_r2s.partition_D(sc)
|
||||
tRS_sD = thr_copy_r2s.partition_D(sC)
|
||||
# (R2S, R2S_M, R2S_N)
|
||||
tRS_rAcc = tiled_copy_r2s.retile(accumulators)
|
||||
|
||||
# Allocate D registers.
|
||||
rD_shape = cute.shape(thr_copy_r2s.partition_S(sc))
|
||||
rD_shape = cute.shape(thr_copy_r2s.partition_S(sC))
|
||||
tRS_rD_layout = cute.make_layout(rD_shape[:3])
|
||||
tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype)
|
||||
size_tRS_rD = cute.size(tRS_rD)
|
||||
|
||||
sepi_for_tma_partition = cute.group_modes(sc, 0, 2)
|
||||
tcgc_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
|
||||
sepi_for_tma_partition = cute.group_modes(sC, 0, 2)
|
||||
tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
|
||||
|
||||
bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
sepi_for_tma_partition,
|
||||
tcgc_for_tma_partition,
|
||||
tCgC_for_tma_partition,
|
||||
)
|
||||
|
||||
epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
|
||||
epi_tile_shape = tcgc_for_tma_partition.shape[1]
|
||||
epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1])
|
||||
epi_tile_shape = tCgC_for_tma_partition.shape[1]
|
||||
epi_tile_layout = cute.make_layout(
|
||||
epi_tile_shape, stride=(epi_tile_shape[1], 1)
|
||||
)
|
||||
|
||||
for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num):
|
||||
# Initialize tma store c_pipeline
|
||||
c_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
|
||||
)
|
||||
c_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.epi_stage,
|
||||
producer_group=c_producer_group,
|
||||
)
|
||||
|
||||
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
||||
# Copy from accumulators to D registers
|
||||
for epi_v in range(size_tRS_rD):
|
||||
for epi_v in cutlass.range_constexpr(size_tRS_rD):
|
||||
tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
|
||||
|
||||
# Type conversion
|
||||
@ -997,10 +1012,6 @@ class HopperWgmmaGemmKernel:
|
||||
# barrier for sync
|
||||
cute.arch.barrier()
|
||||
|
||||
# Get the global memory coordinate for the current epi tile.
|
||||
epi_tile_layout = cute.make_layout(
|
||||
epi_tile_shape, stride=(epi_tile_shape[1], 1)
|
||||
)
|
||||
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
||||
# Copy from shared memory to global memory
|
||||
if warp_idx == 0:
|
||||
@ -1009,11 +1020,14 @@ class HopperWgmmaGemmKernel:
|
||||
bSG_sD[(None, epi_buffer)],
|
||||
bSG_gD[(None, gmem_coord)],
|
||||
)
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
|
||||
c_pipeline.producer_commit()
|
||||
c_pipeline.producer_acquire()
|
||||
|
||||
cute.arch.barrier()
|
||||
|
||||
if warp_idx == 0:
|
||||
c_pipeline.producer_tail()
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@ -1055,9 +1069,7 @@ class HopperWgmmaGemmKernel:
|
||||
mbar_helpers_bytes = 1024
|
||||
|
||||
ab_stage = (
|
||||
(smem_capacity - occupancy * 1024) // occupancy
|
||||
- mbar_helpers_bytes
|
||||
- epi_bytes
|
||||
smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
|
||||
) // ab_bytes_per_stage
|
||||
return ab_stage, epi_stage
|
||||
|
||||
@ -1195,7 +1207,7 @@ class HopperWgmmaGemmKernel:
|
||||
def _compute_grid(
|
||||
c: cute.Tensor,
|
||||
tile_shape_mnk: tuple[int, int, int],
|
||||
cluster_shape_mnk: tuple[int, int, int],
|
||||
cluster_shape_mn: tuple[int, int],
|
||||
) -> tuple[int, int, int]:
|
||||
"""Compute grid shape for the output tensor C.
|
||||
|
||||
@ -1203,8 +1215,8 @@ class HopperWgmmaGemmKernel:
|
||||
:type c: cute.Tensor
|
||||
:param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
|
||||
:type tile_shape_mnk: tuple[int, int, int]
|
||||
:param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
|
||||
:type cluster_shape_mnk: tuple[int, int, int]
|
||||
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
|
||||
:type cluster_shape_mn: tuple[int, int]
|
||||
|
||||
:return: Grid shape for kernel launch.
|
||||
:rtype: tuple[int, int, int]
|
||||
@ -1212,8 +1224,9 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
c_shape = (tile_shape_mnk[0], tile_shape_mnk[1])
|
||||
gc = cute.zipped_divide(c, tiler=c_shape)
|
||||
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk)
|
||||
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk))
|
||||
cluster_shape_mnl = (*cluster_shape_mn, 1)
|
||||
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnl)
|
||||
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnl))
|
||||
return grid
|
||||
|
||||
@staticmethod
|
||||
@ -1363,7 +1376,7 @@ def run(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
tile_shape_mnk: Tuple[int, int, int],
|
||||
tile_shape_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
tolerance: float,
|
||||
warmup_iterations: int,
|
||||
@ -1387,8 +1400,8 @@ def run(
|
||||
:type acc_dtype: Type[cutlass.Numeric]
|
||||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||||
:type a_major/b_major/c_major: str
|
||||
:param tile_shape_mnk: CTA tile shape (M, N, K)
|
||||
:type tile_shape_mnk: Tuple[int, int, int]
|
||||
:param tile_shape_mn: CTA tile shape (M, N)
|
||||
:type tile_shape_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: Cluster shape (M, N)
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
:param tolerance: Tolerance value for reference validation comparison
|
||||
@ -1411,7 +1424,7 @@ def run(
|
||||
f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
|
||||
)
|
||||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||||
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
|
||||
print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}")
|
||||
print(f"Tolerance: {tolerance}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
@ -1420,7 +1433,6 @@ def run(
|
||||
|
||||
# Unpack parameters
|
||||
m, n, k, l = mnkl
|
||||
cluster_shape_mnk = (*cluster_shape_mn, 1)
|
||||
|
||||
# Skip unsupported types
|
||||
if not HopperWgmmaGemmKernel.is_valid_dtypes(
|
||||
@ -1488,7 +1500,7 @@ def run(
|
||||
b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
|
||||
c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
|
||||
|
||||
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
|
||||
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mn, cluster_shape_mn)
|
||||
|
||||
torch_stream = torch.cuda.Stream()
|
||||
stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
@ -1572,7 +1584,7 @@ if __name__ == "__main__":
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.tile_shape_mnk,
|
||||
args.tile_shape_mn,
|
||||
args.cluster_shape_mn,
|
||||
args.tolerance,
|
||||
args.warmup_iterations,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -41,22 +41,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
|
||||
"b_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
|
||||
"tensor(raw_ptr(0x0000000006cff170: f32, generic, align<4>) o (3,4):(4,1), data=\n",
|
||||
" [[ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
|
||||
" [ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
|
||||
" [ 2.000000, 2.000000, 2.000000, 2.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
@ -91,22 +78,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor_value<vector<24xf32> o (4, 2, 3)> -> tensor_value<vector<12xf32> o (4, 3)>\n",
|
||||
"tensor(raw_ptr(0x00000000071acaf0: f32, generic, align<4>) o (4,3):(3,1), data=\n",
|
||||
" [[ 3.000000, 4.000000, 5.000000, ],\n",
|
||||
" [ 9.000000, 10.000000, 11.000000, ],\n",
|
||||
" [ 15.000000, 16.000000, 17.000000, ],\n",
|
||||
" [ 21.000000, 22.000000, 23.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):\n",
|
||||
@ -155,19 +129,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor_value<vector<24xf32> o (4, 2, 3)> -> ?\n",
|
||||
"tensor(raw_ptr(0x00000000013cbbe0: f32, generic, align<4>) o (1):(1), data=\n",
|
||||
" [ 10.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def slice_2():\n",
|
||||
" src_shape = (4, 2, 3)\n",
|
||||
@ -195,40 +159,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
@ -236,28 +169,22 @@
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
" add_res = a_vec + b_vec\n",
|
||||
" res.store(add_res)\n",
|
||||
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
" cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
"\n",
|
||||
" sub_res = a_vec - b_vec\n",
|
||||
" res.store(sub_res)\n",
|
||||
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
" cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
"\n",
|
||||
" mul_res = a_vec * b_vec\n",
|
||||
" res.store(mul_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
" cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" div_res = a_vec / b_vec\n",
|
||||
" res.store(div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
" cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
"\n",
|
||||
" floor_div_res = a_vec // b_vec\n",
|
||||
" res.store(floor_div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
"\n",
|
||||
" mod_res = a_vec % b_vec\n",
|
||||
" res.store(mod_res)\n",
|
||||
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
" cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"a = np.empty((3,), dtype=np.float32)\n",
|
||||
@ -270,68 +197,31 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
|
||||
" a_vec = a.load()\n",
|
||||
"\n",
|
||||
" add_res = a_vec + c\n",
|
||||
" res.store(add_res)\n",
|
||||
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
" cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
"\n",
|
||||
" sub_res = a_vec - c\n",
|
||||
" res.store(sub_res)\n",
|
||||
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
" cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
"\n",
|
||||
" mul_res = a_vec * c\n",
|
||||
" res.store(mul_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
" cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" div_res = a_vec / c\n",
|
||||
" res.store(div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
" cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
"\n",
|
||||
" floor_div_res = a_vec // c\n",
|
||||
" res.store(floor_div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
" cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
"\n",
|
||||
" mod_res = a_vec % c\n",
|
||||
" res.store(mod_res)\n",
|
||||
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
" cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
"\n",
|
||||
"a = np.empty((3,), dtype=np.float32)\n",
|
||||
"a.fill(1.0)\n",
|
||||
@ -342,17 +232,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[False True False]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
@ -378,17 +260,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[3 0 7]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
@ -420,44 +294,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-0.756802, ],\n",
|
||||
" [-0.756802, ],\n",
|
||||
" [-0.756802, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 16.000000, ],\n",
|
||||
" [ 16.000000, ],\n",
|
||||
" [ 16.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def unary_op_1(res: cute.Tensor, a: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
"\n",
|
||||
" sqrt_res = cute.math.sqrt(a_vec)\n",
|
||||
" res.store(sqrt_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
" cute.print_tensor(sqrt_res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" sin_res = cute.math.sin(a_vec)\n",
|
||||
" res.store(sin_res)\n",
|
||||
" cute.print_tensor(res) # prints [-0.756802, -0.756802, -0.756802]\n",
|
||||
" cute.print_tensor(sin_res) # prints [-0.756802, -0.756802, -0.756802]\n",
|
||||
"\n",
|
||||
" exp2_res = cute.math.exp2(a_vec)\n",
|
||||
" res.store(exp2_res)\n",
|
||||
" cute.print_tensor(res) # prints [16.000000, 16.000000, 16.000000]\n",
|
||||
" cute.print_tensor(exp2_res) # prints [16.000000, 16.000000, 16.000000]\n",
|
||||
"\n",
|
||||
"a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
@ -470,29 +323,18 @@
|
||||
"source": [
|
||||
"#### Reduction Operation\n",
|
||||
"\n",
|
||||
"The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, `ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and performs this reduction along the dimensions specified by the `reduction_profile.`. The result is typically a new `TensorSSA` with reduced dimensions or a scalar value if reduces across all axes."
|
||||
"The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, \n",
|
||||
"`ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and \n",
|
||||
"performs this reduction along the dimensions specified by the `reduction_profile`. The result \n",
|
||||
"is typically a new `TensorSSA` with reduced dimensions or a scalar value if it reduces across \n",
|
||||
"all axes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"21.000000\n",
|
||||
"tensor(raw_ptr(0x00007ffd1ea2bca0: f32, rmem, align<32>) o (2):(1), data=\n",
|
||||
" [ 6.000000, ],\n",
|
||||
" [ 15.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00007ffd1ea2bcc0: f32, rmem, align<32>) o (3):(1), data=\n",
|
||||
" [ 6.000000, ],\n",
|
||||
" [ 8.000000, ],\n",
|
||||
" [ 10.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def reduction_op(a: cute.Tensor):\n",
|
||||
@ -507,36 +349,138 @@
|
||||
" 0.0,\n",
|
||||
" reduction_profile=0\n",
|
||||
" )\n",
|
||||
" cute.printf(red_res) # prints 21.000000\n",
|
||||
" cute.printf(red_res) # prints 21.000000\n",
|
||||
"\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 0.0,\n",
|
||||
" reduction_profile=(None, 1)\n",
|
||||
" )\n",
|
||||
" # We can't print the TensorSSA directly at this point, so we store it to a new Tensor and print it.\n",
|
||||
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
|
||||
" res.store(red_res)\n",
|
||||
" cute.print_tensor(res) # prints [6.000000, 15.000000]\n",
|
||||
" cute.print_tensor(red_res) # prints [6.000000, 15.000000]\n",
|
||||
"\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 1.0,\n",
|
||||
" reduction_profile=(1, None)\n",
|
||||
" )\n",
|
||||
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
|
||||
" res.store(red_res)\n",
|
||||
" cute.print_tensor(res) # prints [6.000000, 8.000000, 10.000000]\n",
|
||||
" cute.print_tensor(red_res) # prints [6.000000, 8.000000, 10.000000]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n",
|
||||
"reduction_op(from_dlpack(a))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Broadcast\n",
|
||||
"\n",
|
||||
"`TensorSSA` supports broadcasting operations following NumPy's broadcasting rules. Broadcasting \n",
|
||||
"allows you to perform operations on arrays of different shapes when certain conditions are met. \n",
|
||||
"The key rules are:\n",
|
||||
"\n",
|
||||
"1. Source shape is padded with 1's to match the rank of target shape\n",
|
||||
"2. The size in each mode of source shape must either be 1 or equal to target shape\n",
|
||||
"3. After broadcasting, all modes should match target shape\n",
|
||||
"\n",
|
||||
"Let's look at some examples of broadcasting in action:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def broadcast_examples():\n",
|
||||
" a = cute.make_fragment((1,3), dtype=cutlass.Float32)\n",
|
||||
" a[0] = 0.0\n",
|
||||
" a[1] = 1.0\n",
|
||||
" a[2] = 2.0\n",
|
||||
" a_val = a.load()\n",
|
||||
" cute.print_tensor(a_val.broadcast_to((4, 3)))\n",
|
||||
" # tensor(raw_ptr(0x00007ffe26625740: f32, rmem, align<32>) o (4,3):(1,4), data=\n",
|
||||
" # [[ 0.000000, 1.000000, 2.000000, ],\n",
|
||||
" # [ 0.000000, 1.000000, 2.000000, ],\n",
|
||||
" # [ 0.000000, 1.000000, 2.000000, ],\n",
|
||||
" # [ 0.000000, 1.000000, 2.000000, ]])\n",
|
||||
"\n",
|
||||
" c = cute.make_fragment((4,1), dtype=cutlass.Float32)\n",
|
||||
" c[0] = 0.0\n",
|
||||
" c[1] = 1.0\n",
|
||||
" c[2] = 2.0\n",
|
||||
" c[3] = 3.0\n",
|
||||
" cute.print_tensor(a.load() + c.load())\n",
|
||||
" # tensor(raw_ptr(0x00007ffe26625780: f32, rmem, align<32>) o (4,3):(1,4), data=\n",
|
||||
" # [[ 0.000000, 1.000000, 2.000000, ],\n",
|
||||
" # [ 1.000000, 2.000000, 3.000000, ],\n",
|
||||
" # [ 2.000000, 3.000000, 4.000000, ],\n",
|
||||
" # [ 3.000000, 4.000000, 5.000000, ]])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"broadcast_examples()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "raw"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"The examples above demonstrate two key broadcasting scenarios:\n",
|
||||
"\n",
|
||||
"1. **Row Vector Broadcasting**: In the first example, we create a row vector `a` with shape \n",
|
||||
" (1, 3) containing values [0.0, 1.0, 2.0]. When we broadcast it to shape (4, 3), the values \n",
|
||||
" are repeated across the first dimension, resulting in:\n",
|
||||
" ```\n",
|
||||
" [[0.0, 1.0, 2.0],\n",
|
||||
" [0.0, 1.0, 2.0],\n",
|
||||
" [0.0, 1.0, 2.0],\n",
|
||||
" [0.0, 1.0, 2.0]]\n",
|
||||
" ```\n",
|
||||
" This demonstrates how a row vector can be broadcast to create multiple identical rows.\n",
|
||||
"\n",
|
||||
"2. **Column Vector and Row Vector Addition**: In the second example, we have:\n",
|
||||
" - A row vector `a` with shape (1, 3) containing [0.0, 1.0, 2.0]\n",
|
||||
" - A column vector `c` with shape (4, 1) containing [0.0, 1.0, 2.0, 3.0]\n",
|
||||
" \n",
|
||||
" When we add these together, both vectors are broadcast to shape (4, 3):\n",
|
||||
" - The row vector is broadcast vertically (4 times)\n",
|
||||
" - The column vector is broadcast horizontally (3 times)\n",
|
||||
" \n",
|
||||
" The result is:\n",
|
||||
" ```\n",
|
||||
" [[0.0 + 0.0, 1.0 + 0.0, 2.0 + 0.0],\n",
|
||||
" [0.0 + 1.0, 1.0 + 1.0, 2.0 + 1.0],\n",
|
||||
" [0.0 + 2.0, 1.0 + 2.0, 2.0 + 2.0],\n",
|
||||
" [0.0 + 3.0, 1.0 + 3.0, 2.0 + 3.0]]\n",
|
||||
" ```\n",
|
||||
" =\n",
|
||||
" ```\n",
|
||||
" [[0.0, 1.0, 2.0],\n",
|
||||
" [1.0, 2.0, 3.0],\n",
|
||||
" [2.0, 3.0, 4.0],\n",
|
||||
" [3.0, 4.0, 5.0]]\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"This demonstrates how `TensorSSA` can automatically handle broadcasting of both row and column \n",
|
||||
"vectors in arithmetic operations, following the broadcasting rules where each dimension must \n",
|
||||
"either be 1 or match the target size. The broadcasting is handled implicitly during operations, \n",
|
||||
"making it easy to work with tensors of different shapes.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".venv3_12",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -550,7 +494,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user