Compare commits

...

11 Commits

Author SHA1 Message Date
f3fde58372 Update pyproject.toml
update version to 4.2.1
2025-09-24 01:19:30 -04:00
a8749e67ba Update CHANGELOG.md
format change
2025-09-23 17:33:42 -04:00
c609b86db2 Feature/add bottom causal mask (#2480)
* Rebase to latest

* update

* upd

* Update fmha_fusion.hpp

* Update fmha_fusion.hpp

fixed flipped logic for isQBegin

* Update fmha_fusion.hpp

* Avoid use of booleans

The current expression is confusing

* fmt

* Update fmha_fusion.hpp

Reproduce error/fix with: 
./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend

* add test, format

---------

Co-authored-by: Richard Cai <ricai@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-09-23 14:11:16 -07:00
177a82e251 Rename python/cutlass to python/cutlass_cppgen (#2652) 2025-09-23 14:10:50 -07:00
4260d4aef9 4.2.1 update 2025-09-23 14:09:15 -07:00
ee914c3cec v4.2.1 update. (#2667) 2025-09-23 14:25:14 -04:00
59b61c606f add support matrix 2025-09-17 20:20:50 -07:00
wbn 6b73aedb11 Fixed a typo in the pipeline description docs. (#2623) 2025-09-17 20:17:53 -07:00
ebf5e5effd Fix: a calculation error in the example of dividing out in the 02_layout_algebra doc (#2635) 2025-09-17 20:16:17 -07:00
df3923b0bb Fix doc cute 03_tensor.md link typo (#2627)
* Update 03_tensor.md fix link typo

change path to relative path

* Update 03_tensor.md

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-09-17 20:15:14 -07:00
a49f8062e3 Remove old-version dsl examples (#2645) 2025-09-17 22:23:07 -04:00
96 changed files with 199 additions and 585 deletions

View File

@ -2,6 +2,20 @@
# CUTLASS 4.x
## [4.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.1) (2025-09-22)
### CuTe DSL
* Bug fixes and improvements
- Fixed an issue when running DSL code with cuda-python 13.0
- Fixed an issue when running inductor with DSL code
- Fixed an issue with unexpected logging when running DSL code in FlashInfer
- Fixed the issue reported in https://github.com/NVIDIA/cutlass/issues/2647
- Fixed an issue with conditional definition of variables outside of dynamic control flow
### CUTLASS C++
* Bypass EVT for nosmem blockwise kernels on Blackwell.
* Rename cutlass/python/cutlass directory to cutlass/python/cutlass_cppgen.
## [4.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0) (2025-09-15)
### CuTe DSL

View File

@ -1,9 +1,9 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# Overview
# CUTLASS 4.2.0
# CUTLASS 4.2.1
_CUTLASS 4.2.0 - Sept 2025_
_CUTLASS 4.2.1 - Sept 2025_
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@ -224,7 +224,10 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|NVIDIA B300 Tensor Core GPU |10.3|13.0|
|NVIDIA DRIVE Thor |11.0|13.0|
|NVIDIA GeForce RTX 50x0 series |12.0|12.8|
|NVIDIA DGX Spark |12.1|13.0|
## Target Architecture

View File

@ -39,7 +39,8 @@ set_property(
)
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
@ -119,7 +120,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_CAUSAL_00
TEST_CAUSAL_01
TEST_VARLEN
TEST_HDIM64
TEST_GQA
@ -222,7 +224,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla_fwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_CAUSAL_00
TEST_VARLEN
TEST_HDIM64
TEST_GQA

View File

@ -225,8 +225,8 @@ struct CausalMask : NoMask {
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
}
}
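
For context on the new qend branch above: the masked trip count is now derived from the actual offset between the causal diagonal and the K tiling, rather than unconditionally adding a "corner" trip whenever either extent is not tile-aligned. The sketch below is a plain-Python restatement of the two formulas, not the kernel code; the problem and tile sizes are hypothetical illustration values chosen so the formulas disagree, not the regression case from the commit message.

```python
# Plain-Python restatement of the old and new masked-trip-count formulas above
# (illustration only; all sizes below are hypothetical).
def ceil_div(a, b):
    return -(-a // b)

def old_trips(q, k, tile_q, tile_k, trip_count):
    # Old: add one corner trip whenever either extent is not tile-aligned.
    corner = int(bool(k % tile_k or q % tile_q))
    return min(trip_count, ceil_div(tile_q, tile_k) + corner)

def new_trips(q, k, tile_q, tile_k, trip_count):
    # New: offset of the qend causal diagonal relative to the K tiling.
    offset_tile_q = (k - q) % tile_k
    return min(trip_count, ceil_div(tile_q + offset_tile_q, tile_k))

q, k, tile_q, tile_k = 900, 1028, 128, 128          # hypothetical sizes where the formulas differ
trip_count = ceil_div(k, tile_k)
print(old_trips(q, k, tile_q, tile_k, trip_count))  # 2
print(new_trips(q, k, tile_q, tile_k, trip_count))  # 1
```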

View File

@ -1,314 +0,0 @@
import os
import torch
import argparse
from cuda import cuda
from cuda.bindings import driver
from typing import Type
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass._mlir.dialects import llvm, builtin, vector, arith
WORLD_SIZE = 8
PING_PONG_SIZE = 3
def setup(rank, world_size):
# set environment variables for torch.distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12959"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
class AllReduceKernel:
@cute.jit
def __call__(
self,
rank,
signal,
local_input: cute.Tensor,
local_output: cute.Tensor,
buffer0: cute.Tensor,
buffer1: cute.Tensor,
buffer2: cute.Tensor,
buffer3: cute.Tensor,
buffer4: cute.Tensor,
buffer5: cute.Tensor,
buffer6: cute.Tensor,
buffer7: cute.Tensor,
stream: cuda.CUstream,
):
# define constants for future use
num_of_elements = cute.size(local_input.layout)
# 128 threads per block and 4 elements per thread
tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
tile = cute.size(tv_layout.shape)
buffers = [
buffer0,
buffer1,
buffer2,
buffer3,
buffer4,
buffer5,
buffer6,
buffer7,
]
tiled_buffers = [
cute.logical_divide(buffer, (tile, None, None)) for buffer in buffers
]
tiled_input = cute.zipped_divide(local_input, cute.make_layout(tile))
tiled_output = cute.zipped_divide(local_output, cute.make_layout(tile))
self.kernel(
tiled_buffers[0],
tiled_buffers[1],
tiled_buffers[2],
tiled_buffers[3],
tiled_buffers[4],
tiled_buffers[5],
tiled_buffers[6],
tiled_buffers[7],
tiled_input,
tiled_output,
tv_layout,
cutlass.Int32(signal),
cutlass.Int32(rank),
).launch(
grid=[num_of_elements // tile, 1, 1],
block=[tv_layout.shape[0], 1, 1],
stream=stream,
)
# GPU device kernel
@cute.kernel
def kernel(
self,
buffer0: cute.Tensor,
buffer1: cute.Tensor,
buffer2: cute.Tensor,
buffer3: cute.Tensor,
buffer4: cute.Tensor,
buffer5: cute.Tensor,
buffer6: cute.Tensor,
buffer7: cute.Tensor,
local_input: cute.Tensor,
local_output: cute.Tensor,
tv_layout: cute.Layout,
signal: cutlass.Int32,
rank: cutlass.Int32,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
ping = signal % 3
pong = (signal + 1) % 3
buffers = [
buffer0,
buffer1,
buffer2,
buffer3,
buffer4,
buffer5,
buffer6,
buffer7,
]
def get_buffer():
t = buffers[2]
if rank == cutlass.Int32(0):
t = buffers[0]
if rank == cutlass.Int32(1):
t = buffers[1]
if rank == cutlass.Int32(2):
t = buffers[2]
if rank == cutlass.Int32(3):
t = buffers[3]
if rank == cutlass.Int32(4):
t = buffers[4]
if rank == cutlass.Int32(5):
t = buffers[5]
if rank == cutlass.Int32(6):
t = buffers[6]
if rank == cutlass.Int32(7):
t = buffers[7]
return t
buffer_local = get_buffer()
cta_coord = (None, bidx)
local_tile_in = local_input[cta_coord]
local_tile_out = local_output[cta_coord]
ping_coord = ((None, bidx), None, ping)
read_buffer = buffer_local[ping_coord]
pong_coord = ((None, bidx), None, pong)
clear_buffer = buffer_local[pong_coord]
write_coord = ((None, bidx), rank, ping)
write_buffers = [buffer[write_coord] for buffer in buffers]
# assume all buffers have the same element type with input
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
buffer0.element_type,
num_bits_per_copy=64,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
buffer0.element_type,
num_bits_per_copy=64,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
)
tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
thr_copy = tiled_copy.get_slice(tidx)
thr_write_buffer_list = [
thr_copy.partition_D(tensor) for tensor in write_buffers
]
thr_read_buffer = thr_copy.partition_S(read_buffer)
thr_clear_buffer = thr_copy.partition_D(clear_buffer)
thr_in = thr_copy.partition_S(local_tile_in)
thr_out = thr_copy.partition_D(local_tile_out)
frg_in = cute.make_fragment_like(thr_in)
frg_clear = cute.make_fragment_like(thr_clear_buffer)
frg_acc = cute.make_fragment_like(thr_out)
frg_acc.fill(0.0)
clear_tensor = frg_clear.load()
frg_size = cute.size(clear_tensor.shape)
neg0_i32_vec = cute.full_like(clear_tensor, 0x80000000, cutlass.Int32)
neg0_f32_vec = vector.bitcast(T.vector(frg_size, T.f32()), neg0_i32_vec)
neg0_f32_tensor = cute.TensorSSA(
neg0_f32_vec, clear_tensor.shape, cutlass.Float32
)
frg_clear.store(neg0_f32_tensor)
cute.copy(copy_atom_load, thr_in, frg_in)
for thr_write_buffer in thr_write_buffer_list:
cute.copy(copy_atom_store, frg_in, thr_write_buffer)
cute.copy(copy_atom_store, frg_clear, thr_clear_buffer)
frg_in_vector_neg0_i32 = cute.full_like(
frg_in, cutlass.Int32(0x80000000), cutlass.Int32
)
frg_in_size = cute.size(frg_in.shape)
for i in range(WORLD_SIZE):
read_coord = (None, 0, i)
cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
frg_vector = frg_in.load()
frg_vector_i32 = vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector)
isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
while not isNotNeg0:
cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
frg_vector = frg_in.load()
frg_vector_i32 = vector.bitcast(
T.vector(frg_in_size, T.i32()), frg_vector
)
isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
frg_acc.store(frg_in.load() + frg_acc.load())
cute.copy(copy_atom_stg, frg_acc, thr_out)
def run_all_reduce(rank, M, N, dtype: Type[cutlass.Numeric]):
setup(rank, WORLD_SIZE)
input_tensor = torch.randn(M * N, device=f"cuda:{rank}")
output_tensor = torch.zeros(M * N, device=f"cuda:{rank}")
# init tensors on different devices
t = symm_mem.empty(
[
PING_PONG_SIZE,
WORLD_SIZE,
M * N,
],
device="cuda",
).neg_()
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
buffer_tensor_list = [
hdl.get_buffer(rank, t.shape, t.dtype).permute(2, 1, 0)
for rank in range(WORLD_SIZE)
]
# enable peer access
driver.cuInit(0)
dev_list = [driver.cuDeviceGet(i)[1] for i in range(WORLD_SIZE)]
ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
for i in range(WORLD_SIZE):
driver.cuCtxSetCurrent(ctx_list[i])
for j in range(WORLD_SIZE):
if i == j:
continue
driver.cuCtxEnablePeerAccess(ctx_list[j], 0)
driver.cuCtxSetCurrent(ctx_list[rank])
stream = cutlass.cuda.default_stream()
all_reduce_kernel = AllReduceKernel()
dlpack_buffers = [from_dlpack(x, assumed_align=32) for x in buffer_tensor_list]
all_reduce_kernel(
rank,
0,
from_dlpack(input_tensor, assumed_align=32),
from_dlpack(output_tensor, assumed_align=32),
*dlpack_buffers,
stream,
)
torch.cuda.synchronize(0)
# use torch api to get reference and inplace stored to input_tensor
ref_tensor = input_tensor.clone()
dist.all_reduce(ref_tensor, op=dist.ReduceOp.SUM)
# check result of output tensor, allow small error due to different accumulator datatypes
equal_mask = (ref_tensor.cpu() - output_tensor.cpu()).abs() < 1e-4
result = (equal_mask.sum()).item() == ref_tensor.numel()
if result:
print(f"rank {rank} test passed")
else:
print(f"rank {rank} test failed")
print(
"ref_tensor[ref_tensor != output_tensor]: ",
ref_tensor[ref_tensor != output_tensor],
)
print(
"output_tensor[ref_tensor != output_tensor]: ",
output_tensor[ref_tensor != output_tensor],
)
cleanup()
def main():
M = 1024
N = 1024
# each process will run run_all_reduce on different device
mp.spawn(run_all_reduce, args=(M, N, cutlass.Float32), nprocs=WORLD_SIZE, join=True)
return
if __name__ == "__main__":
main()

View File

@ -1,189 +0,0 @@
import os
import torch
import argparse
from typing import Type
from cuda.bindings import driver
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
def setup(rank, world_size):
# set environment variables for torch.distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12995"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
@cute.kernel
def vector_add_kernel(
g0: cute.Tensor,
g1: cute.Tensor,
g2: cute.Tensor,
g3: cute.Tensor,
g4: cute.Tensor,
g5: cute.Tensor,
g6: cute.Tensor,
g7: cute.Tensor,
gOut: cute.Tensor,
tv_layout: cute.Layout,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
cta_coord = (None, bidx)
local_tile_out = gOut[cta_coord]
local_tile_list = [
g0[cta_coord],
g1[cta_coord],
g2[cta_coord],
g3[cta_coord],
g4[cta_coord],
g5[cta_coord],
g6[cta_coord],
g7[cta_coord],
]
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
g0.element_type,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
g0.element_type,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
)
tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
thr_copy = tiled_copy.get_slice(tidx)
thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list]
thr_out = thr_copy.partition_D(local_tile_out)
frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list]
frg_acc = cute.make_fragment_like(thr_out)
frg_acc.fill(0.0)
for thr, frg in zip(thr_tensor_list, frg_tensor_list):
cute.copy(copy_atom_load, thr, frg)
tmp = frg.load() + frg_acc.load()
frg_acc.store(tmp)
cute.copy(copy_atom_store, frg_acc, thr_out)
@cute.jit
def vector_add(
m0: cute.Tensor,
m1: cute.Tensor,
m2: cute.Tensor,
m3: cute.Tensor,
m4: cute.Tensor,
m5: cute.Tensor,
m6: cute.Tensor,
m7: cute.Tensor,
output: cute.Tensor,
):
# define constants for future use
num_of_elements = cute.size(m0.layout)
# 128 threads per block and 4 elements per thread
tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
tile = cute.size(tv_layout.shape)
tensors = [m0, m1, m2, m3, m4, m5, m6, m7]
divided_tensors = [
cute.zipped_divide(tensor, cute.make_layout(tile)) for tensor in tensors
]
gOut = cute.zipped_divide(output, cute.make_layout(tile)) # ((Tile),(Rest))
vector_add_kernel(
divided_tensors[0],
divided_tensors[1],
divided_tensors[2],
divided_tensors[3],
divided_tensors[4],
divided_tensors[5],
divided_tensors[6],
divided_tensors[7],
gOut,
tv_layout,
).launch(
grid=[num_of_elements // tile, 1, 1],
block=[tv_layout.shape[0], 1, 1],
)
def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]):
setup(rank, world_size)
t = symm_mem.empty(M * N, device="cuda")
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
# get tensors from other devices from the symmetric memory
tensor_list = [hdl.get_buffer(rank, t.shape, t.dtype) for rank in range(world_size)]
tensor_list[rank].random_(0, 100)
# enable peer access
driver.cuInit(0)
dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)]
ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
driver.cuCtxSetCurrent(ctx_list[rank])
for i in range(world_size):
if i == rank:
continue
driver.cuCtxEnablePeerAccess(ctx_list[i], 0)
output = torch.zeros(M * N, device=f"cuda:{rank}")
# we have to explicitly pass each tensor instead of a list of tensors
vector_add(
from_dlpack(tensor_list[0], assumed_align=32),
from_dlpack(tensor_list[1], assumed_align=32),
from_dlpack(tensor_list[2], assumed_align=32),
from_dlpack(tensor_list[3], assumed_align=32),
from_dlpack(tensor_list[4], assumed_align=32),
from_dlpack(tensor_list[5], assumed_align=32),
from_dlpack(tensor_list[6], assumed_align=32),
from_dlpack(tensor_list[7], assumed_align=32),
from_dlpack(output, assumed_align=32),
)
sum_tensor = sum([tensor.cpu() for tensor in tensor_list])
if sum(sum_tensor.cpu() == output.cpu()) == sum_tensor.numel():
print("test passed")
else:
print("test failed")
print(sum_tensor.cpu())
print(output.cpu())
cleanup()
def main():
world_size = torch.cuda.device_count()
M = 1024
N = 1024
# each process will run run_vector_add on different device
mp.spawn(
run_vector_add,
args=(world_size, M, N, cutlass.Float32),
nprocs=world_size,
join=True,
)
return
if __name__ == "__main__":
main()

View File

@ -1435,8 +1435,12 @@ private:
is_same_v<FastF32NoSmemWarpSpecialized2Sm, EpilogueScheduleType> ||
is_same_v<PtrArrayFastF32NoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
is_same_v<PtrArrayFastF32NoSmemWarpSpecialized2Sm, EpilogueScheduleType>;
// Input transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support.
static constexpr bool IsInputTransformSchedule = IsInterleavedComplex || IsFastF32Schedule;
static constexpr bool IsBlockwiseSchedule = is_same_v<BlockwiseNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
is_same_v<BlockwiseNoSmemWarpSpecialized2Sm, EpilogueScheduleType> ||
is_same_v<PtrArrayBlockwiseNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
is_same_v<PtrArrayBlockwiseNoSmemWarpSpecialized2Sm, EpilogueScheduleType>;
// Transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support.
static constexpr bool IsTransformSchedule = IsInterleavedComplex || IsFastF32Schedule || IsBlockwiseSchedule;
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
@ -1470,7 +1474,7 @@ private:
static_assert(is_tuple_v<EpilogueTileType>, "Shape or Tile");
return EpilogueTileType{};
}
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp> || not IsInputTransformSchedule) {
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp> || not IsTransformSchedule) {
// Save register usage for sm103 blockscaled kernels and sm100 cpasync kernels
// to avoid register spilling.
constexpr int EpiM = size<0>(CtaTileShape_MNK{});
@ -1501,7 +1505,7 @@ private:
DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
if constexpr (IsDefaultFusionOp<FusionOp>::value &&\
not is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> && \
(IsInputTransformSchedule || \
(IsTransformSchedule || \
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized1Sm> || \
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized2Sm>)
) {

View File

@ -63,10 +63,14 @@ struct NoSmemWarpSpecialized1Sm {};
struct NoSmemWarpSpecialized2Sm {};
struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {};
struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {};
struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {};
struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {};
// Blackwell TMA schedules
struct TmaWarpSpecialized1Sm {};
struct TmaWarpSpecialized2Sm {};

View File

@ -36,7 +36,7 @@
#define CUTLASS_MAJOR 4
#define CUTLASS_MINOR 2
#define CUTLASS_PATCH 0
#define CUTLASS_PATCH 1
#ifdef CUTLASS_VERSIONS_GENERATED
#include "cutlass/version_extended.h"

View File

@ -151,7 +151,7 @@ For example,
* `(3,6,2,8) / 9 => (1,2,2,8)`
* `(3,6,2,8) / 72 => (1,1,1,4)`
To compute the strides of the strided layout, the residues of the above operation are used to scale the strides of `A`. For instance, the last example `(3,6,2,8):(w,x,y,z) / 72` with strides `(w,x,y,z)` produces `(3*w,6*x,2*x,2*z)` as the strides of the strided layout.
To compute the strides of the strided layout, the residues of the above operation are used to scale the strides of `A`. For instance, the last example `(3,6,2,8):(w,x,y,z) / 72` with strides `(w,x,y,z)` produces `(72*w,24*x,4*y,2*z)` as the strides of the strided layout.
As you may have noticed, we can only divide shapes by certain values and get a sensible result. This is called the **stride divisibility condition** and is statically checked in CuTe when possible.
@ -171,7 +171,7 @@ This operation causes the result to have a shape that is compatible with `B`.
Again, this operation must satisfy a **shape divisibility condition** to yield a sensible result and is statically checked in CuTe when possible.
From the above examples, we can construct the composition `(3,6,2,8):(w,x,y,z) o 16:9 = (1,2,2,4):(3*w,3*x,y,z)`.
From the above examples, we can construct the composition `(3,6,2,8):(w,x,y,z) o 16:9 = (1,2,2,4):(9*w,3*x,y,z)`.
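
Both corrected expressions can be sanity-checked numerically. The standalone script below (an illustration, independent of CuTe and pycute) models a layout as the usual colexicographic coordinate-to-offset map and verifies that the corrected strides reproduce `A(72*i)` for the divide-out example and `A(9*i)` for the composition with `16:9`; the concrete values of `w, x, y, z` are arbitrary placeholders.

```python
# Standalone numeric check of the two corrected stride results above (no CuTe dependency).
def layout(shape, stride):
    """Colexicographic layout: map a flat index to sum(coord[k] * stride[k])."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f

w, x, y, z = 1, 3, 18, 36                         # arbitrary placeholder strides
A = layout((3, 6, 2, 8), (w, x, y, z))            # A = (3,6,2,8):(w,x,y,z)

# Dividing out 72: (1,1,1,4):(72w,24x,4y,2z) should agree with i -> A(72*i).
D = layout((1, 1, 1, 4), (72 * w, 24 * x, 4 * y, 2 * z))
assert all(D(i) == A(72 * i) for i in range(4))

# Composition with B = 16:9: (1,2,2,4):(9w,3x,y,z) should agree with i -> A(9*i).
C = layout((1, 2, 2, 4), (9 * w, 3 * x, y, z))
assert all(C(i) == A(9 * i) for i in range(16))

print("corrected strides verified")
```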
---
#### Example 1 -- Worked Example of Calculating a Composition

View File

@ -217,7 +217,7 @@ for (int i = 0; i < A.size(); ++i)
## Tiling a Tensor
Many of the [`Layout` algebra operations](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/02_layout_algebra.md) can also be applied to `Tensor`.
Many of the [`Layout` algebra operations](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md) can also be applied to `Tensor`.
```cpp
composition(Tensor, Tiler)
logical_divide(Tensor, Tiler)

View File

@ -90,7 +90,7 @@ before issuing other instructions associated with a particular pipeline stage
(e.g., copy or write).
This is a blocking instruction
which blocks further execution of consumer threads
which blocks further execution of producer threads
unless the particular stage waiting to be acquired
is released by a consumer.

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "nvidia-cutlass"
version = "4.2.0.0"
version = "4.2.1.0"
description = "CUTLASS"
readme = "README.md"
requires-python = ">=3.8"

View File

@ -20,6 +20,7 @@ import warnings
import inspect
from types import BuiltinFunctionType
from functools import lru_cache
from inspect import getmembers
from .utils.logger import log
from .common import *
@ -579,3 +580,37 @@ def redirect_builtin_function(fcn):
if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
return executor._builtin_redirector(fcn)
return fcn
def copy_members(dest, src):
"""
Copies all non-callable, non-dunder members from src to dest if they exist in src.
Skips members that are callables or have names starting with double underscores.
"""
if id(dest) == id(src):
return
members = getmembers(dest)
for name, value in members:
if (
name.startswith("__")
or isinstance(value, Callable)
or not hasattr(src, name)
):
continue
setattr(dest, name, getattr(src, name))
def get_locals_or_none(locals, symbols):
"""
Given a locals() dictionary and a list of symbol names, return a list of their values
in the same order as the symbols list. If a symbol is not present in locals, None is returned
for that symbol.
"""
variables = []
for symbol in symbols:
if symbol in locals:
variables.append(locals[symbol])
else:
variables.append(None)
return variables
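
The helper above is what the DSL preprocessor (changed later in this comparison) now emits in place of a literal `[a, b, ...]` list of `write_args`, which relates to the changelog entry about variables conditionally defined outside of dynamic control flow: a name that is not bound would make the literal list raise `NameError`, while `get_locals_or_none` substitutes `None`. A minimal sketch, using a simplified stand-in with the same behavior:

```python
# Simplified stand-in for the get_locals_or_none helper above (illustration only).
def get_locals_or_none(local_vars, symbols):
    return [local_vars.get(symbol) for symbol in symbols]

flag = False
if flag:
    x = 1   # 'x' is only conditionally bound before the lowered control flow

# Literal form: write_args=[x, flag] raises NameError when flag is False.
# Helper form: missing bindings become None instead of raising.
print(get_locals_or_none(locals(), ["x", "flag"]))   # [None, False]
```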

View File

@ -668,12 +668,7 @@ class DSLPreprocessor(ast.NodeTransformer):
ast.keyword(arg="prefetch_stages", value=prefetch_stages),
ast.keyword(
arg="write_args",
value=ast.List(
elts=[
ast.Name(id=arg, ctx=ast.Load()) for arg in write_args
],
ctx=ast.Load(),
),
value=self.generate_get_locals_or_none_call(write_args),
),
ast.keyword(
arg="full_write_args_count",
@ -707,28 +702,6 @@ class DSLPreprocessor(ast.NodeTransformer):
node,
)
def create_loop_call(self, func_name, iter_args):
"""
Assigns the returned value from the loop function directly (without a tuple unpacking).
"""
if len(iter_args) == 0:
return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load()))
elif len(iter_args) == 1:
return ast.Assign(
targets=[ast.Name(id=iter_args[0], ctx=ast.Store())],
value=ast.Name(id=func_name, ctx=ast.Load()),
)
else:
return ast.Assign(
targets=[
ast.Tuple(
elts=[ast.Name(id=var, ctx=ast.Store()) for var in iter_args],
ctx=ast.Store(),
)
],
value=ast.Name(id=func_name, ctx=ast.Load()),
)
def visit_BoolOp(self, node):
# Visit child nodes first
self.generic_visit(node)
@ -1140,10 +1113,10 @@ class DSLPreprocessor(ast.NodeTransformer):
full_write_args_count,
)
assign = ast.copy_location(self.create_loop_call(func_name, write_args), node)
assign = self.create_cf_call(func_name, write_args, node)
# This should work fine as it modifies the AST structure
exprs = exprs + [func_def, assign]
exprs = exprs + [func_def] + assign
if target_var_is_active_before_loop:
# Create a new assignment to the target variable
@ -1429,11 +1402,9 @@ class DSLPreprocessor(ast.NodeTransformer):
func_def = self.create_while_function(
func_name, node, write_args, full_write_args_count
)
assign = ast.copy_location(
self.create_loop_call(func_name, write_args), node
)
assign = self.create_cf_call(func_name, write_args, node)
return [func_def, assign]
return [func_def] + assign
def visit_Try(self, node):
with self.scope_manager:
@ -1447,17 +1418,27 @@ class DSLPreprocessor(ast.NodeTransformer):
self.generic_visit(node)
return node
def create_if_call(self, func_name, yield_args):
def create_cf_call(self, func_name, yield_args, node):
"""Creates the assignment statement for the if function call"""
if not yield_args:
return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load()))
elif len(yield_args) == 1:
return ast.Assign(
return [
ast.copy_location(
ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node
)
]
has_self = False
for i, arg in enumerate(yield_args):
if arg == "self":
has_self = True
yield_args[i] = "yield_self"
break
if len(yield_args) == 1:
assign = ast.Assign(
targets=[ast.Name(id=yield_args[0], ctx=ast.Store())],
value=ast.Name(id=func_name, ctx=ast.Load()),
)
else:
return ast.Assign(
assign = ast.Assign(
targets=[
ast.Tuple(
elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args],
@ -1467,6 +1448,23 @@ class DSLPreprocessor(ast.NodeTransformer):
value=ast.Name(id=func_name, ctx=ast.Load()),
)
if has_self:
fix_self = ast.Expr(
value=ast.Call(
func=self._create_module_attribute(
"copy_members", lineno=node.lineno, col_offset=node.col_offset
),
args=[
ast.Name(id="self", ctx=ast.Load()),
ast.Name(id="yield_self", ctx=ast.Load()),
],
keywords=[],
)
)
return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)]
else:
return [ast.copy_location(assign, node)]
def visit_IfExp(self, node):
"""
Visits an inline if-else expression (ternary operator).
@ -1567,9 +1565,24 @@ class DSLPreprocessor(ast.NodeTransformer):
func_def = self.create_if_function(
func_name, node, yield_args, full_write_args_count
)
assign = ast.copy_location(self.create_if_call(func_name, yield_args), node)
assign = self.create_cf_call(func_name, yield_args, node)
return [func_def, assign]
return [func_def] + assign
def generate_get_locals_or_none_call(self, write_args):
return ast.Call(
func=self._create_module_attribute("get_locals_or_none"),
args=[
ast.Call(
func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]
),
ast.List(
elts=[ast.Constant(value=arg) for arg in write_args],
ctx=ast.Load(),
),
],
keywords=[],
)
def create_if_function(self, func_name, node, write_args, full_write_args_count):
test_expr = self.visit(node.test)
@ -1627,10 +1640,7 @@ class DSLPreprocessor(ast.NodeTransformer):
), # ast.Name(id="pred", ctx=ast.Load())
ast.keyword(
arg="write_args",
value=ast.List(
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
ctx=ast.Load(),
),
value=self.generate_get_locals_or_none_call(write_args),
),
]
@ -1813,10 +1823,7 @@ class DSLPreprocessor(ast.NodeTransformer):
ast.keyword(arg="pred", value=test_expr),
ast.keyword(
arg="write_args",
value=ast.List(
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
ctx=ast.Load(),
),
value=self.generate_get_locals_or_none_call(write_args),
),
]
decorator = ast.copy_location(
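
When `self` appears among the yielded arguments, `create_cf_call` above renames it to `yield_self` in the generated assignment and appends a `copy_members(self, yield_self)` call, so the data attributes produced by the lowered branch are copied back onto the original object rather than only rebinding a local name. A minimal sketch of that effect, using a simplified stand-in for the `copy_members` helper defined in the previous file:

```python
# Simplified stand-in for copy_members (illustration only): copy non-callable,
# non-dunder attributes that exist on both objects from src onto dest.
from inspect import getmembers

def copy_members(dest, src):
    if dest is src:
        return
    for name, value in getmembers(dest):
        if name.startswith("__") or callable(value) or not hasattr(src, name):
            continue
        setattr(dest, name, getattr(src, name))

class State:
    def __init__(self):
        self.counter = 0

self_obj = State()                  # stands in for the caller's `self`
yield_self = State()                # stands in for the value returned by the lowered branch
yield_self.counter = 3
copy_members(self_obj, yield_self)  # mirrors the generated copy_members(self, yield_self)
assert self_obj.counter == 3
```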

View File

@ -255,7 +255,13 @@ def initialize_cuda_context(device_id: int = 0, flags: int = 0):
_log().info(f"{cuDevice} <-- cuDeviceGet")
# Create context
_log().info(f"cuCtxCreate {0} {cuDevice}")
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
if cuda.CUDA_VERSION >= 13000:
# Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2
# and v3 API has been removed from CTK 13.
# See https://github.com/NVIDIA/cuda-python/pull/792
context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice))
else:
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
_log().info(f"{context} <-- cuCtxCreate")
return context

View File

@ -47,7 +47,8 @@ def setup_log(
if log_to_console or log_to_file:
logger.setLevel(log_level)
else:
logger.setLevel(logging.NOTSET)
# Makes sure logging is OFF
logger.setLevel(logging.CRITICAL + 1)
# Clear existing handlers to prevent duplicate logs
if logger.hasHandlers():
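
For context on the level change above: `logging.NOTSET` does not silence a logger; it makes the logger defer to its ancestors' effective level (`WARNING` on the root logger by default), so records can still be emitted. Setting the level above `CRITICAL` is what actually turns output off. A small illustration (the logger name is arbitrary):

```python
# Why NOTSET is not "off": it defers to the ancestors' effective level (illustration only).
import logging

logger = logging.getLogger("cutlass.demo")   # arbitrary logger name
logger.addHandler(logging.StreamHandler())

logger.setLevel(logging.NOTSET)
logger.warning("emitted: effective level falls back to the root logger (WARNING)")

logger.setLevel(logging.CRITICAL + 1)        # higher than every standard level
logger.critical("suppressed: below the logger's level, nothing reaches the handler")
```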

View File

@ -55,3 +55,5 @@ LaunchConfig = _dsl.BaseDSL.LaunchConfig
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
gpu = _dsl.cutlass_gpu
cuda = _dsl.cuda_helpers
CACHE_FILE = "compiled_cache.db"

View File

@ -31,6 +31,8 @@ from ..base_dsl.ast_helpers import (
range_perf_warning,
cf_symbol_check,
redirect_builtin_function,
copy_members,
get_locals_or_none,
)
from ..base_dsl import *

View File

@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.2.0
nvidia-cutlass-dsl==4.2.1

View File

@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '4.2.0'
this.__version__ = '4.2.1'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch

View File

@ -7467,8 +7467,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]],
[[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]],
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):

View File

@ -811,10 +811,14 @@ class EpilogueScheduleType(enum.Enum):
NoSmemWarpSpecialized2Sm = enum_auto()
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
@ -834,10 +838,14 @@ EpilogueScheduleTag = {
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
@ -858,10 +866,14 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
@ -926,6 +938,8 @@ def to_grouped_schedule(schedule, grouped):
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm,
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm,
# SM103
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='4.2.0',
version='4.2.1',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='4.2.0',
version='4.2.1',
description='Python implementation of CuTe',
packages=['pycute'],
)

View File

@ -20,7 +20,7 @@ packages =
cutlass_library.source
pycute
package_dir =
cutlass_cppgen=python/cutlass
cutlass_cppgen=python/cutlass_cppgen
cutlass_library=python/cutlass_library
cutlass_library.source=.
pycute=python/pycute

View File

@ -69,6 +69,7 @@ template<cute::UMMA::Major SFAMajor,
int ScaleGranularityN,
int ScaleGranularityK,
bool Is2SM,
bool NoSmemEpilogue,
class LayoutA,
class LayoutB,
class LayoutCD,
@ -77,8 +78,10 @@ template<cute::UMMA::Major SFAMajor,
bool groupwise_test(
Int<ScaleGranularityM>, Int<ScaleGranularityN>, Int<ScaleGranularityK>, C<Is2SM>,
LayoutA, LayoutB, LayoutCD,
MmaTileShape, ClusterShape) {
MmaTileShape, ClusterShape,
C<NoSmemEpilogue>) {
using Epilogue1SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>;
using Epilogue2SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized2Sm>;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, SFAMajor, SFBMajor>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
@ -90,7 +93,7 @@ bool groupwise_test(
float, float,
cutlass::float_e4m3_t, LayoutCD, 16,
cutlass::float_e4m3_t, LayoutCD, 16,
conditional_t<Is2SM, cutlass::epilogue::TmaWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>
conditional_t<Is2SM, Epilogue2SM, Epilogue1SM>
>::CollectiveOp;
using CollectiveMainloop =
@ -259,11 +262,26 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_128,_128,_128>{},
Shape<_1,_1,_1>{});
Shape<_1,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
}
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128x128x128_1x1x1_2x2x32_scale_direct_store) {
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::K>(
Int<2>{}, Int<2>{}, Int<32>{}, false_type{},
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_128,_128,_128>{},
Shape<_1,_1,_1>{},
true_type{});
EXPECT_TRUE(passed);
}
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256x128x128_2x1x1_64x4x32_scale) {
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::MN>(
@ -271,7 +289,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_256,_128,_128>{},
Shape<_2,_1,_1>{});
Shape<_2,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
@ -284,7 +303,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_128,_128,_128>{},
Shape<_1,_1,_1>{});
Shape<_1,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
@ -297,7 +317,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_256,_128,_128>{},
Shape<_2,_1,_1>{});
Shape<_2,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
@ -311,7 +332,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_256,_128,_128>{},
Shape<_2,_1,_1>{});
Shape<_2,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);