CUTLASS 3.3.0 (#1167)

* Release 3.3.0

Adds support for mixed precision GEMMs on Hopper and Ampere
Adds support for < 16B aligned GEMMs on Hopper
Enhancements to EVT
Enhancements to Python interface
Enhancements to sub-byte type handling in CuTe
Several other bug-fixes and performance improvements.

* minor doc update
This commit is contained in:
Pradeep Ramani
2023-11-02 08:09:05 -07:00
committed by GitHub
parent 922fb5108b
commit c008b4aea8
263 changed files with 16214 additions and 5008 deletions

View File

@ -28,4 +28,8 @@
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
add_subdirectory(unit)
else()
# Always provide at least the phony test_unit target.
add_custom_target(test_unit)
endif()

View File

@ -36,8 +36,9 @@ Utilities for defining Conv2D problem sizes for testing.
This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h
"""
from cutlass_library import ConvMode
import cutlass
from cutlass import ConvMode
from cutlass.shape import Conv2DProblemSize

View File

@ -34,10 +34,11 @@
Utility functions for Conv2d tests.
"""
from cutlass_library import SubstituteTemplate
import torch
import cutlass
from cutlass import (
from cutlass_library import (
ConvKind,
ConvMode,
DataType,
@ -50,7 +51,6 @@ from cutlass import (
ShortLayoutTypeNames,
SplitKMode,
)
from cutlass.backend.utils.software import SubstituteTemplate
from cutlass.shape import Conv2DProblemSize
from cutlass.utils.datatypes import numpy_type, torch_type
@ -301,17 +301,19 @@ class Conv2dLauncherFrontend:
tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B)
tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C)
tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last)
self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
stride=(ps.stride_h, ps.stride_w),
padding=(ps.pad_h, ps.pad_w),
dilation=(ps.dilation_h, ps.dilation_w),
alpha=alpha, beta=beta,
split_k=(split_k_mode, split_k_slices))
args.sync()
tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation)
torch.cuda.synchronize()
passed = torch.equal(tensor_D, tensor_D_ref)
passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06)
return passed
@ -378,7 +380,8 @@ def add_test(
conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch")
for ps in problem_sizes:
if not validate_problem_size(ps, conv_kind, split_k_slices): continue
if not validate_problem_size(ps, conv_kind, split_k_slices):
continue
self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0))

View File

@ -38,9 +38,11 @@ import random
import tempfile
import unittest
from cutlass_library import ConvMode
import cutlass
if cutlass.utils.datatypes.torch_available:
if cutlass.utils.datatypes.is_torch_available():
import torch
@ -88,7 +90,7 @@ def _generate_problems(dtype, num):
def _generate_conv2d_problem(conv_kind, dtype, ps):
"""
Utility function to generate conv2d inputs
:param conv_kind: kind of convolution
:type conv_kind: str
:param dtype: data type of tensors
@ -114,7 +116,7 @@ def _generate_conv2d_problem(conv_kind, dtype, ps):
return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes]
@unittest.skipIf(not cutlass.utils.datatypes.torch_available, 'PyTorch must be available to run PyTorch extension tests')
@unittest.skipIf(not cutlass.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests')
class PyTorchExtensionTest(unittest.TestCase):
def test_gemm(self):
@ -183,18 +185,18 @@ class PyTorchExtensionTest(unittest.TestCase):
Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)]
Ds = mod.run(As, Bs, Cs, alpha, beta)
check_all(Ds, Ds_ref)
def test_conv2d_fprop(self):
torch.manual_seed(2023)
dtype = torch.float16
plan = cutlass.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32)
plan.activation = "relu"
op = plan.construct()
with tempfile.TemporaryDirectory() as tmpdir:
mod = cutlass.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
problem_size = cutlass.shape.Conv2DProblemSize(
1, 4, 4, 16,
8, 3, 3, 16,
@ -202,50 +204,50 @@ class PyTorchExtensionTest(unittest.TestCase):
3, 3,
1, 1
)
A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size)
stride = (problem_size.stride_h, problem_size.stride_w)
padding = (problem_size.pad_h, problem_size.pad_w)
alpha = 1.0
beta = 0.5
D_ref = alpha * torch.ops.aten.conv2d(
A, B, stride=stride, padding=padding
) + beta * C
D_ref = torch.nn.functional.relu(D_ref)
D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta)
assert torch.allclose(D, D_ref)
assert torch.allclose(D, D_ref)
# Test serial split-K
D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
assert torch.allclose(D, D_serial_split_k)
# Test parallel split-K
D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
assert torch.allclose(D, D_parallel_split_k)
def test_conv2d_dgrad(self):
torch.manual_seed(2023)
dtype = torch.float16
plan = cutlass.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32)
op = plan.construct()
with tempfile.TemporaryDirectory() as tmpdir:
mod = cutlass.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
problem_size = cutlass.shape.Conv2DProblemSize(
1, 4, 4, 16,
8, 3, 3, 16,
0, 0,
3, 3,
1, 1,
cutlass.ConvMode.CrossCorrelation,
ConvMode.CrossCorrelation,
1, 1
)
A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size)
stride = (problem_size.stride_h, problem_size.stride_w)
padding = (problem_size.pad_h, problem_size.pad_w)
@ -254,32 +256,32 @@ class PyTorchExtensionTest(unittest.TestCase):
beta = 0.5
input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W)
D_ref = alpha * torch.nn.grad.conv2d_input(
input_size, B, A,
input_size, B, A,
stride=stride, padding=padding
) + beta * C
D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, )
assert torch.allclose(D, D_ref)
assert torch.allclose(D, D_ref)
def test_conv2d_wgrad(self):
torch.manual_seed(2023)
dtype = torch.float16
plan = cutlass.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32)
op = plan.construct()
with tempfile.TemporaryDirectory() as tmpdir:
mod = cutlass.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
problem_size = cutlass.shape.Conv2DProblemSize(
1, 4, 4, 16,
8, 3, 3, 16,
0, 0,
3, 3,
1, 1,
cutlass.ConvMode.CrossCorrelation,
ConvMode.CrossCorrelation,
1, 1
)
A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size)
stride = (problem_size.stride_h, problem_size.stride_w)
padding = (problem_size.pad_h, problem_size.pad_w)
@ -288,17 +290,17 @@ class PyTorchExtensionTest(unittest.TestCase):
beta = 0.5
weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S)
D_ref = alpha * torch.nn.grad.conv2d_weight(
B, weight_size, A,
B, weight_size, A,
stride=stride, padding=padding
) + beta * C
D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta)
assert torch.allclose(D, D_ref)
assert torch.allclose(D, D_ref)
# Test serial split-K
D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
assert torch.allclose(D, D_serial_split_k)
# Test parallel split-K
D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
assert torch.allclose(D, D_parallel_split_k)

View File

@ -40,9 +40,9 @@ import unittest
import cutlass
from cutlass import Tensor
import cutlass.backend.evt
from cutlass.profiler import CUDAEventProfiler
from cutlass.shape import GemmCoord
from cutlass.utils.datatypes import torch_type
from cutlass.utils.profiler import CUDAEventProfiler
class EVTReferenceModule:

View File

@ -43,7 +43,7 @@ import cutlass
from cutlass.backend.utils.device import device_cc
import torch
from utils import LayoutCombination, add_test_gemm
from utils import LayoutCombination
cutlass.set_log_level(logging.WARNING)

View File

@ -67,58 +67,58 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f16, cc=c
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
if __name__ == '__main__':

View File

@ -135,6 +135,10 @@ add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass.DataType.f16
add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8])
add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8])
# Tests with void-C kernels
add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None,
cluster_shape=[2, 1, 1], element_C=cutlass.DataType.void)
if __name__ == '__main__':
unittest.main()

View File

@ -68,31 +68,31 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f32, cc=c
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)

View File

@ -68,30 +68,30 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f64, cc=c
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)

View File

@ -0,0 +1,72 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Low-level functionality tests for GEMM with mixed operands on SM80
"""
from functools import partial
import logging
import unittest
import cutlass
from cutlass.backend.utils.device import device_cc
from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 80
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
class GemmMixedSm80(unittest.TestCase):
    """Empty test-case container; mixed-input GEMM tests are attached to it dynamically in __main__."""
# Shared configuration for every mixed-input GEMM test registered below:
# f16 x s8 (or s8 x f16) TensorOp kernels on SM80 with an f32 accumulator.
add_test_mixed = partial(
    add_test_gemm,
    cls=GemmMixedSm80,
    element=cutlass.DataType.f16,
    cc=cc,
    cluster_shape=[1, 1, 1],
    opclass=cutlass.OpcodeClass.TensorOp,
    threadblock_shape=[128, 128, 64],
    warp_count=[2, 2, 1],
    stages=3,
    element_accumulator=cutlass.DataType.f32,
)

# Test with upcast on A (s8 operand A, f16 operand B)
for layout_combo in (LayoutCombination.TNT, LayoutCombination.TNN):
    add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=layout_combo)

# Test with upcast on B (f16 operand A, s8 operand B)
for layout_combo in (LayoutCombination.TNT, LayoutCombination.TNN):
    add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=layout_combo)

if __name__ == '__main__':
    unittest.main()

View File

@ -68,30 +68,30 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.s8, cc=cc
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3)
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32,
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)

View File

@ -37,7 +37,7 @@ import subprocess
import torch
from cutlass import (
from cutlass_library import (
DataType,
DataTypeSize,
GemmUniversalMode,
@ -49,7 +49,6 @@ from cutlass import (
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.shape import GemmCoord, MatrixCoord
from cutlass.utils.datatypes import torch_type
@ -65,16 +64,6 @@ class GemmUniversalLauncher:
compiler_mode= "nvcc",
**kwargs,
) -> None:
# Create the reduction kernel, if needed
self.reduction_operation: ReductionOperation = ReductionOperation(
shape=MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
self.math_operation = operation.tile_description.math_instruction.math_operation
self.verification = verification
@ -88,19 +77,26 @@ class GemmUniversalLauncher:
op_list = [operation]
if operation.arch < 90:
# Split K via Python is currently only supported for pre-SM90 kernels
self.reduction_operation: ReductionOperation = ReductionOperation(
shape=MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
op_list.append(self.reduction_operation)
compiler.add_module(op_list, bypass_cache=False)
self.operation = operation
self.dtype_A = torch_type(operation.A.element)
self.dtype_B = torch_type(operation.B.element)
self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element)
self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element)
self.dtype_C = torch_type(operation.C.element)
self.dtype_D = torch_type(operation.C.element)
self.dtype_D = torch_type(operation.epilogue_functor.element_output)
accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator]
element_size = DataTypeSize[operation.A.element]
element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])
if element_size == 1:
self.rand_max = 1
@ -154,7 +150,18 @@ class GemmUniversalLauncher:
def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
# If any tensor is on CPU, place all tensors on CPU unless only
# tensor C is on CPU
devices = [x.device.type for x in [tensor_A, tensor_B, tensor_C]]
# Handle mixed-input cases by casting to the larger data type and overriding
# to whatever the data type of the larger type is
if self.dtype_A != self.dtype_B:
if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]:
tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device)
else:
tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device)
devices = [x.device.type for x in [tensor_A, tensor_B]]
if tensor_C is not None:
devices.append(tensor_C.device.type)
if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]:
device = torch.device("cpu")
else:
@ -162,14 +169,17 @@ class GemmUniversalLauncher:
tensor_A = tensor_A.to(device)
tensor_B = tensor_B.to(device)
tensor_C = tensor_C.to(device)
if tensor_C is not None:
tensor_C = tensor_C.to(device)
dtype = torch_type(self.compute_type)
alpha_torch = torch.tensor([alpha], device=device).to(dtype)
beta_torch = torch.tensor([beta], device=device).to(dtype)
tmp = tensor_A @ tensor_B
tensor_D_ref = (alpha_torch * tmp) + (tensor_C * beta_torch)
tensor_D_ref = (alpha_torch * tmp)
if tensor_C is not None:
tensor_D_ref += (tensor_C * beta_torch)
return tensor_D_ref.to(self.dtype_D)
def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
@ -199,12 +209,22 @@ class GemmUniversalLauncher:
self.dtype_B,
self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout),
)
tensor_C, tensor_C_ref = self.uniform_init(
if self.dtype_C is not None:
tensor_C, tensor_C_ref = self.uniform_init(
(true_batch_count, problem_size.m, problem_size.n),
self.dtype_C,
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
)
else:
tensor_C = None
tensor_C_ref = None
tensor_D, _ = self.uniform_init(
(true_batch_count, problem_size.m, problem_size.n),
self.dtype_C,
self.dtype_D,
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
)
tensor_D = torch.zeros_like(tensor_C)
tensor_D = torch.zeros_like(tensor_D)
if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
alpha = int(alpha)
@ -248,6 +268,10 @@ class GemmUniversalLauncher:
if self.verification:
if mode == GemmUniversalMode.GemmSplitKParallel:
reduction_arguments.sync()
# Free memory allocated by args because we are not
# calling `arguments.sync()` in this case (which will free memory)
arguments.free()
else:
arguments.sync()
tensor_D_ref = self.reference(
@ -274,9 +298,6 @@ class GemmUniversalLauncher:
if mode == GemmUniversalMode.GemmSplitKParallel:
del reduction_arguments
cur_size = get_allocated_size()
assert cur_size == 0, f"{cur_size} B of memory were not released after this run"
return passed

View File

@ -30,9 +30,10 @@
#
#################################################################################################
import cutlass
from cutlass_library import SubstituteTemplate
from cutlass import (
import cutlass
from cutlass_library import (
DataTypeNames,
EpilogueScheduleSuffixes,
KernelScheduleSuffixes,
@ -42,7 +43,6 @@ from cutlass import (
ShortLayoutTypeNames
)
from cutlass.backend import library
from cutlass.backend.utils.software import SubstituteTemplate
from gemm_testbed import test_all_gemm
@ -82,6 +82,7 @@ def get_name(
stages,
element_a,
element_b,
element_c,
arch,
opclass,
kernel_schedule=None,
@ -102,6 +103,7 @@ def get_name(
:type stages: int
:param element_a: data type of operand A
:param element_b: data type of operand B
:param element_c: data type of operand C
:param arch: compute capability of kernel being generated
:type arch: int
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
@ -122,7 +124,7 @@ def get_name(
"arch": str(arch),
"eA": DataTypeNames[element_a],
"eB": DataTypeNames[element_b],
"eC": DataTypeNames[element_output],
"eC": DataTypeNames[element_c],
"lA": ShortLayoutTypeNames[layouts[0]],
"lB": ShortLayoutTypeNames[layouts[1]],
"lC": ShortLayoutTypeNames[layouts[2]],
@ -161,7 +163,10 @@ def add_test_gemm(
swizzle=None,
kernel_schedule=None,
epilogue_schedule=None,
compilation_modes=['nvcc', 'nvrtc']):
compilation_modes=['nvcc', 'nvrtc'],
element_A=None,
element_B=None,
element_C=None):
"""
Create test-running functions with the given specification and set it as a method of ``cls``.
@ -195,22 +200,38 @@ def add_test_gemm(
:param epilogue_schedule: epilogue schedule to use
:type epilogue_schedule: cutlass.EpilogueScheduleType
:param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc')
:type compilation_modes: list
:type compilation_modes: list,
:param element_A: data type of operand A. If set, overrides ``element``
:type element_A: cutlass.DataType
:param element_B: data type of operand B. If set, overrides ``element``
:type element_B: cutlass.DataType
:param element_C: data type of operand C. If set, overrides ``element``
:type element_C: cutlass.DataType
"""
if element_A is None:
element_A = element
if element_B is None:
element_B = element
if element_C is None:
element_C = element
if element_output is None:
element_output = element
if element_accumulator is None:
element_accumulator = element
for compilation_mode in compilation_modes:
def run(self):
"""
Dynamically-generated function that constructs a GEMM operation and verifies it against
multiple test cases.
"""
element_A = element
element_B = element
layout_A, layout_B, layout_C = layouts
alignment_A, alignment_B, alignment_C = alignments
plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B,
element_C=element_output, element_D=element_output,
element_C=element_C, element_D=element_output,
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
element_accumulator=element_accumulator,
kernel_cc=cc)
@ -233,7 +254,7 @@ def add_test_gemm(
name = get_name(
layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass,
stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass,
kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')
setattr(cls, name, run)

View File

@ -0,0 +1,57 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Tests for a successful installation of the CUTLASS Python interface
"""
import os
import unittest
import cutlass
import cutlass_library
class InstallationTest(unittest.TestCase):
    """Checks that the CUTLASS Python interface was installed correctly."""

    def test_cutlass_source_paths(self):
        """
        Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages
        """
        src_file = 'include/cutlass/cutlass.h'
        # Both packages must ship the CUTLASS headers; probe each in turn.
        candidate_paths = [
            os.path.join(cutlass_library.source_path, src_file),
            os.path.join(cutlass.CUTLASS_PATH, src_file),
        ]
        for path in candidate_paths:
            assert os.path.isfile(path), f"Unable to locate file {path}. Installation has not succeeded."
if __name__ == "__main__":
unittest.main()

View File

@ -50,7 +50,7 @@ class Conv2dEquivalence:
"""
def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator,
alignment_A, alignment_B, alignment_C):
self.element_A = element_A
self.element_B = element_B
self.element_C = element_C
@ -59,21 +59,21 @@ class Conv2dEquivalence:
self.alignment_A = alignment_A
self.alignment_B = alignment_B
self.alignment_C = alignment_C
self.conv_kind = conv_kind
self.plan = cutlass.op.Conv2d(
kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C,
element_D=element_D, element_accumulator=element_accumulator)
self.op = self.plan.construct(
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_C=self.alignment_C)
def _plans_equal(self, other_plan) -> bool:
"""
Compares whether two plans are equal
:param other_plan: plan to compare against the default Conv2d
:type other_plan: cutlass.op.Conv2d
@ -81,9 +81,9 @@ class Conv2dEquivalence:
:rtype: bool
"""
other_op = other_plan.construct(
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_C=self.alignment_C)
return self.op.rt_module.emit() == other_op.rt_module.emit()
def generic_test(self):
@ -91,16 +91,16 @@ class Conv2dEquivalence:
Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types
and layouts for constructing the Conv2d interface
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
# Test when specifying all parameters
plan_other = cutlass.op.Conv2d(
kind=self.conv_kind,
element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
element_D=self.element_D, element_accumulator=self.element_accumulator)
assert self._plans_equal(plan_other)
# Test when specifying all parameters but A
plan_other = cutlass.op.Conv2d(
kind=self.conv_kind,
@ -108,7 +108,7 @@ class Conv2dEquivalence:
element_D=self.element_D, element_accumulator=self.element_accumulator,
element=self.element_A)
assert self._plans_equal(plan_other)
# Test when specifying all parameters but A and B as tensors using generic element and output
plan_other = cutlass.op.Conv2d(
kind=self.conv_kind,
@ -116,7 +116,7 @@ class Conv2dEquivalence:
element_D=self.element_D, element_accumulator=self.element_accumulator,
element=self.element_A)
assert self._plans_equal(plan_other)
# Test without explicit accumulator. Only run if the type of C and the accumulator are equal
if self.element_C == self.element_accumulator:
plan_other = cutlass.op.Conv2d(
@ -125,18 +125,18 @@ class Conv2dEquivalence:
element_D=self.element_D,
element=self.element_A)
assert self._plans_equal(plan_other)
# Test with only the generic types. Only run if the types of A, B, C, and D are the same
if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
and self.element_A == self.element_accumulator):
plan_other = cutlass.op.Conv2d(kind=self.conv_kind, element=self.element_A)
assert self._plans_equal(plan_other)
def numpy_test(self):
"""
Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
import numpy as np
@ -145,7 +145,7 @@ class Conv2dEquivalence:
type_C = datatypes.numpy_type(self.element_C)
type_D = datatypes.numpy_type(self.element_D)
type_accum = datatypes.numpy_type(self.element_accumulator)
size = (2, 2)
A = np.zeros(size, dtype=type_A)
B = np.zeros(size, dtype=type_B)
@ -153,49 +153,49 @@ class Conv2dEquivalence:
D = np.zeros(size, dtype=type_D)
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
def torch_test(self):
"""
Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend
"""
if not datatypes.torch_available:
if not datatypes.is_torch_available():
return
import torch
type_A = datatypes.torch_type(self.element_A)
type_B = datatypes.torch_type(self.element_B)
type_C = datatypes.torch_type(self.element_C)
type_D = datatypes.torch_type(self.element_D)
type_accum = datatypes.torch_type(self.element_accumulator)
size = (2, 2)
A = torch.empty(size, dtype=type_A)
B = torch.empty(size, dtype=type_B)
C = torch.empty(size, dtype=type_C)
D = torch.empty(size, dtype=type_D)
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D):
# Test when specifying all parameters via tensors
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum)
assert self._plans_equal(plan_np)
# Test when specifying all parameters but A as tensors
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A)
assert self._plans_equal(plan_np)
# Test when specifying all parameters but A and B as tensors and using generic element and output
if type_A == type_B:
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A)
assert self._plans_equal(plan_np)
# Test without explicit accumulator. Only run if the types of C and the accumulator are equal.
if type_C == type_accum:
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D)
assert self._plans_equal(plan_np)
# Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum):
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, element=type_A)
@ -223,20 +223,20 @@ type2alignment = {
}
def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator):
test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}"
def run(self):
conv2d_eq = Conv2dEquivalence(
conv_kind=conv_kind,
conv_kind=conv_kind,
element_A=element_A, element_B=element_B,
element_C=element_C, element_D=element_D,
element_accumulator=element_accumulator,
element_accumulator=element_accumulator,
alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B],
alignment_C=type2alignment[element_C]
)
conv2d_eq.test_all()
setattr(ConvEquivalenceTest, test_name, run)
for conv_kind in ["fprop", "wgrad", "dgrad"]:
@ -255,25 +255,25 @@ class Conv2dErrorTests(unittest.TestCase):
"""
Tests various error scenarios that arise with the high-level Gemm interface
"""
def test_alignment(self):
"""
Tests case in which the alignment specified is unsupported
"""
plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16)
with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'):
op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3)
def test_invalid_tile_description(self):
"""
Tests scenarios in which an invalid tile description is provided for a given CC
"""
plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16)
td = plan.tile_descriptions()[0]
td.threadblock_shape=[17, 32, 5]
plan.tile_description = td
with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'):
plan.compile()

View File

@ -93,13 +93,16 @@ class EVTErrorTests(unittest.TestCase):
"""
Test when the epilogue consumes too much shared memory
"""
def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5):
def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8):
D1 = accum + C1
D2 = D1 + C2
D3 = D2 + C3
D4 = D3 + C4
D = D4 + C5
return D, D1, D2, D3, D4
D5 = D4 + C5
D6 = D5 + C6
D7 = D6 + C7
D = D7 + C8
return D, D1, D2, D3, D4, D5, D6, D7
example_tensors = {
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
@ -108,10 +111,16 @@ class EVTErrorTests(unittest.TestCase):
"C3": self.fake_tensor(np.float16, (6, 512, 512)),
"C4": self.fake_tensor(np.float16, (6, 512, 512)),
"C5": self.fake_tensor(np.float16, (6, 512, 512)),
"C6": self.fake_tensor(np.float16, (6, 512, 512)),
"C7": self.fake_tensor(np.float16, (6, 512, 512)),
"C8": self.fake_tensor(np.float16, (6, 512, 512)),
"D1": self.fake_tensor(np.float16, (6, 512, 512)),
"D2": self.fake_tensor(np.float16, (6, 512, 512)),
"D3": self.fake_tensor(np.float16, (6, 512, 512)),
"D4": self.fake_tensor(np.float16, (6, 512, 512)),
"D5": self.fake_tensor(np.float16, (6, 512, 512)),
"D6": self.fake_tensor(np.float16, (6, 512, 512)),
"D7": self.fake_tensor(np.float16, (6, 512, 512)),
"D": self.fake_tensor(np.float16, (6, 512, 512))
}

View File

@ -85,7 +85,7 @@ class GemmEquivalence:
Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types
and layouts for constructing the Gemm interface
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
# Test when specifying all parameters
@ -126,7 +126,7 @@ class GemmEquivalence:
"""
Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
import numpy as np

View File

@ -114,7 +114,7 @@ void run_test_integer_range_all() {
);
destination.sync_host();
// Verify conversion
bool passed = true;
for (int i = 0; i < kN; ++i) {
@ -124,7 +124,7 @@ void run_test_integer_range_all() {
}
}
EXPECT_TRUE(passed) << " FastNumericArrayConverter failed";
// Print out results for the failed conversion.
if (!passed) {
for (int i = 0; i < kN; ++i) {

View File

@ -48,11 +48,20 @@ TEST(float_e4m3_t, host_conversion) {
for (int i = -8; i < 8; ++i) {
float f = static_cast<float>(i);
cutlass::int4b_t s = static_cast<cutlass::int4b_t>(i);
FP8 w = static_cast<FP8>(s);
FP8 x = static_cast<FP8>(i);
FP8 y = static_cast<FP8>(f);
EXPECT_TRUE(static_cast<cutlass::int4b_t>(w) == s);
EXPECT_TRUE(static_cast<int>(x) == i);
EXPECT_TRUE(static_cast<float>(y) == f);
if (i >= 0) {
cutlass::uint4b_t u = static_cast<cutlass::uint4b_t>(i);
FP8 z = static_cast<FP8>(u);
EXPECT_TRUE(static_cast<unsigned>(z) == u);
}
}
// Try out default-ctor (zero initialization of primitive proxy type)
@ -72,11 +81,20 @@ TEST(float_e5m2_t, host_conversion) {
for (int i = -8; i < 8; ++i) {
float f = static_cast<float>(i);
cutlass::int4b_t s = static_cast<cutlass::int4b_t>(i);
FP8 w = static_cast<FP8>(s);
FP8 x = static_cast<FP8>(i);
FP8 y = static_cast<FP8>(f);
EXPECT_TRUE(static_cast<cutlass::int4b_t>(w) == s);
EXPECT_TRUE(static_cast<int>(x) == i);
EXPECT_TRUE(static_cast<float>(y) == f);
if (i >= 0) {
cutlass::uint4b_t u = static_cast<cutlass::uint4b_t>(i);
FP8 z = static_cast<FP8>(u);
EXPECT_TRUE(static_cast<cutlass::uint4b_t>(z) == u);
}
}
// Try out default-ctor (zero initialization of primitive proxy type)

View File

@ -60,7 +60,7 @@ __global__ void convert(
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Destination, typename Source, int Count>
template <typename Destination, typename Source, int Count, int Range = 4>
void run_test(const char dest_name[], const char source_name[]) {
const int kN = Count;
@ -69,9 +69,11 @@ void run_test(const char dest_name[], const char source_name[]) {
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
auto source_ref = source.host_ref();
auto destination_ref = destination.host_ref();
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = Source(i % 4);
source_ref.at({0, i}) = Source(i % Range);
}
source.sync_device();
@ -84,9 +86,67 @@ void run_test(const char dest_name[], const char source_name[]) {
destination.sync_host();
for (int i = 0; i < kN; ++i) {
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]))
<< "Destination type: " << dest_name
<< ", Source type: " << source_name
EXPECT_TRUE(float(destination_ref.at({0, i})) == float(source_ref.at({0, i})))
<< "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i}))
<< ", Source type: " << source_name << " " << float(source_ref.at({0, i}))
<< ", Count: " << Count;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Device kernel that applies a NumericArrayConverter to a single array of
// Count elements, passing a per-element scale-factor array to the converter.
// Each of destination/source/scale_factor points to one Array in device memory.
template <typename Destination, typename Source, typename ScaleFactor, int Count>
__global__ void convert_with_scale_factor(
cutlass::Array<Destination, Count> *destination,
cutlass::Array<Source, Count> const *source,
cutlass::Array<ScaleFactor, Count> const *scale_factor) {
cutlass::NumericArrayConverter<Destination, Source, Count> convert;
// Converter overload taking a scale-factor array alongside the source values.
*destination = convert(*source, *scale_factor);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Host-side driver for the scale-factor conversion test.
//
// Fills a source tensor with values in [0, Range), a scale-factor tensor with
// values in [1, 8], launches convert_with_scale_factor on a single thread, and
// verifies on the host that each converted element equals
// source / scale_factor (compared after promotion to float).
//
// dest_name / source_name / scale_factor_name are only used to label failures.
template <typename Destination, typename Source, typename ScaleFactor, int Count, int Range = 4>
void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[]) {
const int kN = Count;
// Single-thread launch: the kernel converts the whole Array in one call.
dim3 grid(1, 1);
dim3 block(1, 1);
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
cutlass::HostTensor<ScaleFactor, cutlass::layout::RowMajor> scale_factor({1, kN});
auto source_ref = source.host_ref();
auto destination_ref = destination.host_ref();
auto scale_factor_ref = scale_factor.host_ref();
// Source values cycle through [0, Range) so narrow types stay in range.
for (int i = 0; i < kN; ++i) {
source_ref.at({0, i}) = Source(i % Range);
}
// Scale factors cycle through [1, 8]; never zero, so the reference divide is safe.
for (int i = 0; i < kN; ++i) {
scale_factor_ref.at({0, i}) = ScaleFactor(1 + i % 8);
}
source.sync_device();
scale_factor.sync_device();
convert_with_scale_factor<Destination, Source, ScaleFactor, kN><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data()),
reinterpret_cast<cutlass::Array<ScaleFactor, kN> const *>(scale_factor.device_data())
);
destination.sync_host();
// Host reference: converter with a scale factor divides source by the factor.
for (int i = 0; i < kN; ++i) {
float ref = float(source_ref.at({0, i})) / float(scale_factor_ref.at({0, i}));
EXPECT_TRUE(float(destination_ref.at({0, i})) == ref)
<< "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i}))
<< ", Source type: " << source_name << " " << float(source_ref.at({0, i}))
<< ", Count: " << Count;
}
}
@ -98,7 +158,16 @@ void run_test(const char dest_name[], const char source_name[]) {
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(NumericConversion, f32_to_f16_rn) {
int const kN = 1;
constexpr int kN = 1;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::half_t;
const char dest_name[] = "half_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f32x2_to_f16x2_rn) {
constexpr int kN = 2;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::half_t;
@ -107,7 +176,7 @@ TEST(NumericConversion, f32_to_f16_rn) {
}
TEST(NumericConversion, f32x8_to_f16x8_rn) {
int const kN = 8;
constexpr int kN = 8;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::half_t;
@ -394,4 +463,50 @@ TEST(NumericConversion, fe5m2_to_bf16_array) {
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// These are included as regression tests for a special case when N = 4.
// Regression test for the special-cased N = 4 converter path (int4b_t -> fe5m2).
TEST(NumericConversion, int4b_t_to_fe5m2_t_array_4) {
  // Use constexpr for kN, consistent with the other tests in this file.
  constexpr int kN = 4;
  using Source = cutlass::int4b_t;
  const char source_name[] = "int4b_t";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
// Regression test for the special-cased N = 4 converter path (int -> fe4m3).
TEST(NumericConversion, int_to_fe4m3_t_array_4) {
  // Use constexpr for kN, consistent with the other tests in this file.
  constexpr int kN = 4;
  using Source = int;
  const char source_name[] = "int";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
// Regression test for the special-cased N = 4 converter path (int2b_t -> fe4m3).
TEST(NumericConversion, int2b_t_to_fe4m3_t_array_4) {
  // Use constexpr for kN, consistent with the other tests in this file.
  constexpr int kN = 4;
  using Source = cutlass::int2b_t;
  const char source_name[] = "int2b_t";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
// Regression test for the special-cased N = 4 converter path (fe5m2 -> double).
TEST(NumericConversion, fe5m2_t_to_double_array_4) {
  // Use constexpr for kN, consistent with the other tests in this file.
  constexpr int kN = 4;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = double;
  const char dest_name[] = "double";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
// Wider-array (N = 32) coverage for the int -> fe4m3 conversion path.
TEST(NumericConversion, int_to_fe4m3_t_array_32) {
  // Use constexpr for kN, consistent with the other tests in this file.
  constexpr int kN = 32;
  using Source = int;
  const char source_name[] = "int";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

View File

@ -32,6 +32,7 @@
#include "cutlass_unit_test.h"
#include <cutlass/trace.h>
#include <cute/pointer.hpp>
TEST(CuTe_core, Pointer)
@ -45,7 +46,7 @@ TEST(CuTe_core, Pointer)
// Test T* overloads (T can be nonconst or const)
{
using T = float;
using expected_type = cute::gmem_ptr<T>;
using expected_type = cute::gmem_ptr<T*>;
T* p = nullptr;
// explicit template argument
@ -58,7 +59,7 @@ TEST(CuTe_core, Pointer)
}
{
using T = float const;
using expected_type = cute::gmem_ptr<T>;
using expected_type = cute::gmem_ptr<T*>;
T* p = nullptr;
// explicit template argument
@ -74,7 +75,7 @@ TEST(CuTe_core, Pointer)
// (these require an explicit template argument)
{
using T = float;
using expected_type = cute::gmem_ptr<T>;
using expected_type = cute::gmem_ptr<T*>;
void* p = nullptr;
auto gmem_p0 = cute::make_gmem_ptr<T>(p);
@ -82,7 +83,7 @@ TEST(CuTe_core, Pointer)
}
{
using T = float const;
using expected_type = cute::gmem_ptr<T>;
using expected_type = cute::gmem_ptr<T*>;
void const* p = nullptr;
auto gmem_p0 = cute::make_gmem_ptr<T>(p);
@ -92,14 +93,14 @@ TEST(CuTe_core, Pointer)
// Test nullptr_t overload.
{
using T = float;
using expected_type = cute::gmem_ptr<T>;
using expected_type = cute::gmem_ptr<T*>;
auto gmem_p0 = cute::make_gmem_ptr<T>(nullptr);
static_assert(cute::is_same_v<decltype(gmem_p0), expected_type>);
}
{
using T = float const;
using expected_type = cute::gmem_ptr<T>;
using expected_type = cute::gmem_ptr<T*>;
auto gmem_p0 = cute::make_gmem_ptr<T>(nullptr);
static_assert(cute::is_same_v<decltype(gmem_p0), expected_type>);

View File

@ -416,7 +416,6 @@ TEST(SM90_CuTe_Hopper, Tma_Load_InternalType)
test_tma_load<half_t, uint64_t>(gmem_layout, smem_layout);
test_tma_load< float, uint64_t>(gmem_layout, smem_layout);
test_tma_load<double, uint64_t>(gmem_layout, smem_layout);
}
// Complex<double> is 128bit, which the TMA has no concept of

View File

@ -43,7 +43,7 @@ namespace cutlass::test {
template <class ElementType, class SmemLayout>
struct SharedStorage
{
cute::array_aligned<ElementType, cute::cosize_v<SmemLayout>> smem;
cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
cute::uint64_t tma_load_mbar[1];
};
@ -62,26 +62,26 @@ tma_test_device_cute(T const* g_in, T* g_out,
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<T, SmemLayout>;
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Construct SMEM tensor
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
// Shared memory barriers use 64bits in SMEM for synchronization
uint64_t* tma_load_mbar = shared_storage.tma_load_mbar;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA = tma.get_tma_tensor(shape(gmem_layout));
Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout);
Tensor mB = make_tensor(make_gmem_ptr<T>(g_out), gmem_layout);
constexpr int R = rank_v<CTA_Tiler>;
Tensor gA = local_tile(mA, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
Tensor gB = local_tile(mB, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
//
// Prepare the TMA_LOAD
//
auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice
Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N)
@ -89,11 +89,13 @@ tma_test_device_cute(T const* g_in, T* g_out,
if (thread0()) {
print(tma);
print("TILE : "); print(cta_tiler); print("\n");
print(" mA : "); print( mA.data()); print(" o "); print( mA.layout()); print("\n");
print(" gA : "); print( gA.data()); print(" o "); print( gA.layout()); print("\n");
print("tAgA_x: "); print(tAgA_x.data()); print(" o "); print(tAgA_x.layout()); print("\n");
print(" sA : "); print( sA.data()); print(" o "); print( sA.layout()); print("\n");
print("tAsA_x: "); print(tAsA_x.data()); print(" o "); print(tAsA_x.layout()); print("\n");
print(" mA : "); print( mA); print("\n");
print(" mB : "); print( mB); print("\n");
print(" gA : "); print( gA); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA_x: "); print(tAgA_x); print("\n");
print("tAsA_x: "); print(tAsA_x); print("\n");
}
#endif
@ -111,9 +113,9 @@ tma_test_device_cute(T const* g_in, T* g_out,
#if 0
if (thread0()) {
print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n");
print("tAsA : "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n");
print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
print("tBgB : "); print(tBgB); print("\n");
}
#endif
@ -121,7 +123,7 @@ tma_test_device_cute(T const* g_in, T* g_out,
for (int stage = 0; stage < size<1>(tAgA); ++stage)
{
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
constexpr int kTmaTransactionBytes = size(sA) * sizeof_bits_v<T> / 8;
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<T, size(sA)>);
if (threadIdx.x == 0)
{
@ -146,9 +148,15 @@ tma_test_device_cute(T const* g_in, T* g_out,
// print_tensor(sA);
//}
for (int i = threadIdx.x; i < size(sA); i += blockDim.x) {
tBgB(i,stage) = sA(i);
// for (int i = threadIdx.x; i < size(sA); i += blockDim.x) {
// tBgB(i,stage) = sA(i);
// }
// Subbyte elements could cause race conditions, so be even more conservative
if (thread0()) {
copy(sA, tBgB(_,stage));
}
__syncthreads();
}
}
@ -161,30 +169,38 @@ test_tma_load(CopyOp const& copy_op,
CTA_Tile const& cta_tile)
{
using namespace cute;
thrust::host_vector<T> h_in(cosize(gmem_layout));
for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); }
thrust::device_vector<T> d_in = h_in;
thrust::device_vector<T> d_out(h_in.size(), T(-1));
Tensor gA = make_tensor(d_in.data().get(), gmem_layout);
// Allocate and initialize host test data
size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
thrust::host_vector<char> h_in(N);
Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); }
// Allocate and initialize device test data
thrust::device_vector<char> d_in = h_in;
thrust::device_vector<char> d_out(h_in.size(), char(-1));
// Create TMA for this device Tensor
Tensor gA = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_in.data())), gmem_layout);
auto tma = make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
//print(tma);
// Launch
int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
tma_test_device_cute<<<1, 128, smem_size>>>(
thrust::raw_pointer_cast(d_in.data()),
thrust::raw_pointer_cast(d_out.data()),
reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
reinterpret_cast<T*> (raw_pointer_cast(d_out.data())),
tma, cta_tile,
gmem_layout,
smem_layout);
thrust::host_vector<T> h_out = d_out;
// Copy results back to host
thrust::host_vector<char> h_out = d_out;
Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);
// Validate the results, and tolerate the first 3 errors:
Tensor hA_in = make_tensor(h_in.data(), gmem_layout);
Tensor hA_out = make_tensor(h_out.data(), gmem_layout);
// Validate the results. Print only the first 3 errors.
int count = 3;
for (int i = 0; i < cute::size(gmem_layout) && count > 0; ++i) {
for (int i = 0; i < size(hA_out) && count > 0; ++i) {
EXPECT_EQ(hA_in(i), hA_out(i));
if (hA_in(i) != hA_out(i)) {
--count;

View File

@ -43,7 +43,7 @@ namespace cutlass::test {
template <class ElementType, class SmemLayout>
struct SharedStorage
{
cute::array_aligned<ElementType, cute::cosize_v<SmemLayout>> smem;
cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
};
#if CUDA_12_0_SM90_FEATURES_SUPPORTED
@ -61,24 +61,24 @@ tma_test_device_cute(T const* g_in, T* g_out,
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<T, SmemLayout>;
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Construct SMEM tensor
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout);
Tensor mA = make_tensor(make_gmem_ptr<T>(g_in), gmem_layout);
Tensor mB = tma.get_tma_tensor(shape(gmem_layout));
constexpr int R = rank_v<CTA_Tiler>;
Tensor gA = local_tile(mA, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
Tensor gB = local_tile(mB, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
//
// Prepare the TMA_STORE
//
auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice
Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N)
Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
@ -121,11 +121,17 @@ tma_test_device_cute(T const* g_in, T* g_out,
// Read in trivially gmem -> smem
//
for (int i = threadIdx.x; i < size(sB); i += blockDim.x) {
sB(i) = tAgA(i,stage);
// for (int i = threadIdx.x; i < size(sB); i += blockDim.x) {
// sB(i) = tAgA(i,stage);
// }
// Subbyte elements could cause race conditions, so be even more conservative
if (thread0()) {
copy(tAgA(_,stage), sB);
}
__syncthreads();
cute::cp_async_wait<0>();
//
// Perform the TMA_STORE
@ -148,30 +154,38 @@ test_tma_store(CopyOp const& copy_op,
CTA_Tile const& cta_tile)
{
using namespace cute;
thrust::host_vector<T> h_in(cosize(gmem_layout));
for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); }
thrust::device_vector<T> d_in = h_in;
thrust::device_vector<T> d_out(h_in.size(), T(-1));
Tensor gA = make_tensor(d_out.data().get(), gmem_layout);
// Allocate and initialize host test data
size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
thrust::host_vector<char> h_in(N);
Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); }
// Allocate and initialize device test data
thrust::device_vector<char> d_in = h_in;
thrust::device_vector<char> d_out(h_in.size(), char(-1));
// Create TMA for this device Tensor
Tensor gA = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_out.data())), gmem_layout);
auto tma = make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
//print(tma);
// Launch
int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
tma_test_device_cute<<<1, 128, smem_size>>>(
thrust::raw_pointer_cast(d_in.data()),
thrust::raw_pointer_cast(d_out.data()),
reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
reinterpret_cast<T*> (raw_pointer_cast(d_out.data())),
tma, cta_tile,
gmem_layout,
smem_layout);
thrust::host_vector<T> h_out = d_out;
// Copy results back to host
thrust::host_vector<char> h_out = d_out;
Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);
// Validate the results, and tolerate the first 3 errors:
Tensor hA_in = make_tensor(h_in.data(), gmem_layout);
Tensor hA_out = make_tensor(h_out.data(), gmem_layout);
// Validate the results. Print only the first 3 errors.
int count = 3;
for (int i = 0; i < cute::size(gmem_layout) && count > 0; ++i) {
for (int i = 0; i < size(hA_out) && count > 0; ++i) {
EXPECT_EQ(hA_in(i), hA_out(i));
if (hA_in(i) != hA_out(i)) {
--count;

View File

@ -242,6 +242,24 @@ cutlass_test_unit_add_executable(
sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
BATCH_SOURCES ON
BATCH_SIZE 4
# Upcast on Operand A
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu
# Upcast on Operand B
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_sm90
@ -272,10 +290,22 @@ cutlass_test_unit_add_executable(
BATCH_SOURCES ON
BATCH_SIZE 4
sm90_gemm_f16_f16_f16_alignx_tensor_op.cu
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32.cu
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized.cu
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_cooperative.cu
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_pingpong.cu
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized.cu
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_pingpong.cu
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized.cu
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_cooperative.cu
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_pingpong.cu
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized.cu
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_cooperative.cu
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_pingpong.cu
)
# Fused epilogue tests
@ -298,7 +328,6 @@ cutlass_test_unit_add_executable(
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
@ -311,7 +340,6 @@ cutlass_test_unit_add_executable(
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu
sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90
@ -319,6 +347,8 @@ cutlass_test_unit_add_executable(
BATCH_SIZE 4
sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu
sm90_gemm_f8_f8_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu
sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu
)
cutlass_test_unit_add_executable(
@ -341,23 +371,6 @@ cutlass_test_unit_add_executable(
sm80_gemm_f16_f16_f32_tensor_op_f32.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
BATCH_SOURCES ON
BATCH_SIZE 4
# Upcast on Operand A
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu
# Upcast on Operand B
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_f64

View File

@ -49,16 +49,18 @@
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "testbed_utils.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/fast_math.h"
#include "cutlass/platform/platform.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cute/int_tuple.hpp"
#include "cute/layout.hpp"
namespace test {
namespace gemm {
@ -68,9 +70,9 @@ namespace device {
namespace detail{
// Helper classes that take default data type when
// Helper classes that take default data type when
// the Gemm::EpilogueOutputOp does not have ElementCompute
// and ElementScalar.
// and ElementScalar.
// (e.g. when Sm90TreeVisitor is used as FusionCallbacks)
template <typename Gemm, typename Default, typename = void>
struct ElementComputeType {
@ -138,6 +140,34 @@ private:
int iterations_ = 20;
};
// The maxium swizzle size to use
//
// This class, like Splits above makes it harder to confuse
// the order of arguments of the various run(...) functions in this file.
class MaxSwizzleSize {
public:
MaxSwizzleSize() = default;
template<class IntegralNotBool,
__CUTE_REQUIRES((std::is_integral_v<IntegralNotBool> &&
!std::is_same_v<IntegralNotBool, bool>)) >
explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {}
explicit operator int() const { return max_swizzle_size_; }
private:
int max_swizzle_size_ = 1;
};
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
template <
typename Gemm,
template <class T> class ActivationFunctor_ = cutlass::epilogue::thread::Identity
@ -161,6 +191,8 @@ struct TestbedImpl {
using ElementScalar = typename ElementScalarType<Gemm, ElementCompute>::Type;
using ActivationFunctor = ActivationFunctor_<ElementCompute>;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
@ -190,6 +222,7 @@ struct TestbedImpl {
using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t<StrideB>;
using LayoutTagC = cutlass::detail::StrideToLayoutTagA_t<StrideC>;
using LayoutTagD = cutlass::detail::StrideToLayoutTagA_t<StrideD>;
using LayoutTagVector = cutlass::layout::PackedVectorLayout;
/// Initialization
StrideA stride_a;
@ -323,10 +356,10 @@ struct TestbedImpl {
// 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode
auto a_coord = cutlass::make_Coord(M * L, K);
auto c_coord = cutlass::make_Coord(M * L, N);
// Cutlass has Row/Col major refers to MxK times KxN matrix product,
// Cutlass has Row/Col major refers to MxK times KxN matrix product,
// so the HostTensorB should be treated as KxN in "coord"'s view
auto b_coord = cutlass::make_Coord(K, N * L);
tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
@ -387,7 +420,7 @@ struct TestbedImpl {
std::ofstream file(fname.str());
file
<< "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
<< ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n";
<< ", alpha: " << alpha << ", beta: " << beta << "\n\n";
file
<< "A =\n" << tensor_A.host_view()
@ -404,7 +437,7 @@ struct TestbedImpl {
bool verify(
ProblemShapeType problem_size,
ElementScalar alpha,
ElementScalar beta)
ElementScalar beta)
{
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto M = cute::size<0>(problem_shape_MNKL);
@ -412,13 +445,13 @@ struct TestbedImpl {
auto K = cute::size<2>(problem_shape_MNKL);
auto L = cute::size<3>(problem_shape_MNKL);
auto A = cute::make_tensor(tensor_A.host_data(),
auto A = cute::make_tensor(detail::make_iterator(tensor_A.host_data()),
cute::make_layout(cute::make_shape(M, K, L), stride_a));
auto B = cute::make_tensor(tensor_B.host_data(),
auto B = cute::make_tensor(detail::make_iterator(tensor_B.host_data()),
cute::make_layout(cute::make_shape(N, K, L), stride_b));
auto C = cute::make_tensor(tensor_C.host_data(),
auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()),
cute::make_layout(cute::make_shape(M, N, L), stride_c));
auto D = cute::make_tensor(reference_D.host_data(),
auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()),
cute::make_layout(cute::make_shape(M, N, L), stride_d));
auto Bias = cute::make_tensor(static_cast<ElementCompute*>(nullptr),
cute::make_layout(cute::make_shape(M, cute::_1{})));
@ -451,7 +484,6 @@ struct TestbedImpl {
};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
return compare_reference(problem_shape_MNKL, alpha, beta);
}
@ -529,8 +561,10 @@ struct TestbedImpl {
ElementScalar alpha = ElementScalar(1),
ElementScalar beta = ElementScalar(0),
bool profiling = false,
detail::Iterations iterations = Iterations{},
detail::Splits splits = Splits{})
detail::Iterations iterations = detail::Iterations{},
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{},
detail::Splits splits = detail::Splits{})
{
// Fail test if insufficient CUDA device
if (!sufficient()) {
@ -557,7 +591,10 @@ struct TestbedImpl {
typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args;
if constexpr (std::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
scheduler_args = { static_cast<int>(splits) };
scheduler_args = { static_cast<int>(splits), static_cast<int>(max_swizzle), raster_order };
}
else {
scheduler_args = { static_cast<int>(max_swizzle), raster_order };
}
// DefaultEpilogue
@ -613,7 +650,7 @@ struct TestbedImpl {
//
bool passed = this->verify(problem_size, alpha, beta);
if (!passed) {
std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta)
std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta
<< "\n";
}
@ -648,6 +685,8 @@ struct Testbed3x {
using LayoutTagC = typename TestBedImpl::LayoutTagC;
using LayoutTagD = typename TestBedImpl::LayoutTagD;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
// Detail Implementation
TestBedImpl impl_;
@ -661,7 +700,7 @@ struct Testbed3x {
uint64_t seed_ = TestBedImpl::kDefaultSeed)
: impl_(init_A_, init_B_, init_C_, seed_) {}
Testbed3x(
Testbed3x(
typename LayoutTagA::Stride stride_factor_A_,
typename LayoutTagB::Stride stride_factor_B_,
typename LayoutTagC::Stride stride_factor_C_,
@ -684,12 +723,14 @@ struct Testbed3x {
typename TestBedImpl::ProblemShapeType problem_size,
ElementScalar alpha = ElementScalar(1),
ElementScalar beta = ElementScalar(0),
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{},
detail::Splits splits = detail::Splits{},
bool profiling = false,
detail::Iterations iterations = detail::Iterations{})
{
return impl_.run(
problem_size, alpha, beta, profiling, iterations, splits
problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits
);
}
};
@ -722,13 +763,15 @@ struct Testbed3xFusionOperation {
using StrideD = typename Kernel::StrideD;
using ProblemShapeType = typename Kernel::ProblemShape;
using ElementAccumulator = typename Kernel::ElementAccumulator;
//
// FusionOperation derived types/queries
//
using FusionOp = typename Gemm::EpilogueOutputOp;
static_assert(cute::is_base_of_v<cutlass::epilogue::fusion::FusionOperation, FusionOp>);
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
// fusion types are potentially void if the fusion is not supported
// helper so we don't try to construct HostTensor with void type
template <typename T, typename U = uint8_t>
@ -744,11 +787,17 @@ struct Testbed3xFusionOperation {
cutlass::epilogue::thread::Identity<ElementCompute>>;
static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported;
static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported;
static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported;
static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported;
static constexpr bool IsAuxEnabled = FusionOp::IsAuxOutSupported;
static constexpr bool IsAbsMaxEnabled = FusionOp::IsAbsMaxSupported;
static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported;
static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported;
static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported &&
(cute::is_same_v<ElementD, cutlass::float_e4m3_t> ||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>);
static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported &&
(cute::is_same_v<ElementAux, cutlass::float_e4m3_t> ||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>);
// Legacy support for deprecated bias-elementwise collective, will be removed next release
using EpiloguePolicy = typename Epilogue::DispatchPolicy;
static constexpr bool IsLegacy =
@ -773,6 +822,7 @@ struct Testbed3xFusionOperation {
cutlass::HostTensor<ElementAux , LayoutTagAux > tensor_Aux;
cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux;
// References
cutlass::HostTensor<ElementBias, LayoutTagVector> reference_dbias;
cutlass::HostTensor<ElementAux , LayoutTagAux > reference_Aux;
cutlass::HostTensor<ElementAmax, LayoutTagScalar> reference_abs_max_Aux;
cutlass::HostTensor<ElementAmax, LayoutTagScalar> reference_abs_max_D;
@ -791,12 +841,6 @@ struct Testbed3xFusionOperation {
// Random distribution with which to initialize the bias vector
cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform;
// Factors used for calculating relative equality. These default
// values are borrowed from those used by default in the CUTLASS
// profiler for performing relative equality checks.
float epsilon = 0.05f;
float nonzero_floor = 1.0f / 256.0f;
//
// Methods
//
@ -853,7 +897,7 @@ struct Testbed3xFusionOperation {
else {
beta.resize(col_vector_coord);
EXPECT_TRUE(impl_.initialize_tensor(beta.host_view(), init_scale, impl_.seed + 2024));
}
}
}
else {
alpha.resize(scalar_coord, use_device_scalars);
@ -885,13 +929,34 @@ struct Testbed3xFusionOperation {
bias.sync_device();
}
if constexpr (IsAbsMaxEnabled) {
abs_max_D.resize(scalar_coord);
abs_max_D.sync_device();
reference_abs_max_D.resize(scalar_coord);
if constexpr (IsDeBiasEnabled) {
bias.resize(col_vector_coord);
reference_dbias.resize(col_vector_coord);
cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0));
cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0));
bias.sync_device();
}
if constexpr (IsAuxEnabled) {
if constexpr (IsAbsMaxEnabledD) {
abs_max_D.resize(scalar_coord);
// ensure in-place device reductions perform their own initialization
cutlass::reference::host::TensorFill(abs_max_D.host_view(),
CUTLASS_STL_NAMESPACE::numeric_limits<ElementAmax>::max());
abs_max_D.sync_device();
reference_abs_max_D.resize(scalar_coord);
cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0));
}
if constexpr (IsAuxInEnabled) {
auto aux_coord = cutlass::make_Coord(M * L, N);
auto aux_layout = cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(aux_coord, typename LayoutTagAux::Stride{});
tensor_Aux.resize(aux_coord, aux_layout);
EXPECT_TRUE(impl_.initialize_tensor(tensor_Aux.host_view(), impl_.init_C, impl_.seed + 2023));
tensor_Aux.sync_device();
stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t<LayoutTagAux>{}, cute::make_shape(M, N, L));
}
if constexpr (IsAuxOutEnabled) {
auto aux_coord = cutlass::make_Coord(M * L, N);
auto aux_layout = cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(aux_coord, typename LayoutTagAux::Stride{});
tensor_Aux.resize(aux_coord, aux_layout);
@ -905,10 +970,14 @@ struct Testbed3xFusionOperation {
scale_Aux.sync_device();
}
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledAux) {
abs_max_Aux.resize(scalar_coord);
// ensure in-place device reductions perform their own initialization
cutlass::reference::host::TensorFill(abs_max_Aux.host_view(),
CUTLASS_STL_NAMESPACE::numeric_limits<ElementAmax>::max());
abs_max_Aux.sync_device();
reference_abs_max_Aux.resize(scalar_coord);
cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0));
}
}
@ -922,9 +991,16 @@ struct Testbed3xFusionOperation {
cutlass::TensorView<Element, Layout> const& lhs,
cutlass::TensorView<Element, Layout> const& rhs) const {
// Factors used for calculating relative equality. CUTLASS's relative-equality
// checks in include/cutlass/relatively_equal.h are inspired by
// https://floating-point-gui.de/errors/comparison/. This reference suggests using
// the minimum normal value of a given type as the nonzero_floor.
Element epsilon(0.1f);
Element nonzero_floor(std::numeric_limits<Element>::min());
if (check_relative_equality) {
return cutlass::reference::host::TensorRelativelyEquals(
lhs, rhs, Element(epsilon), Element(nonzero_floor));
lhs, rhs, epsilon, nonzero_floor);
}
else {
return cutlass::reference::host::TensorEquals(lhs, rhs);
@ -933,6 +1009,7 @@ struct Testbed3xFusionOperation {
/// Compares computed reference with device reference and outputs to a file if incorrect
bool compare_reference(cute::Shape<int,int,int,int> problem_shape_MNKL) {
auto [M, N, K, L] = problem_shape_MNKL;
auto coord_0 = cutlass::make_Coord(0);
@ -947,17 +1024,24 @@ struct Testbed3xFusionOperation {
}
bool passed = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view());
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledD) {
abs_max_D.sync_host();
passed &= equality_check(reference_abs_max_D.host_view(), abs_max_D.host_view());
}
if constexpr (IsAuxEnabled) {
if constexpr (IsDeBiasEnabled) {
bias.sync_host();
EXPECT_GT(cutlass::reference::host::TensorNorm(bias.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_dbias.host_view()), 0);
passed &= equality_check(reference_dbias.host_view(), bias.host_view());
}
if constexpr (IsAuxOutEnabled) {
tensor_Aux.sync_host();
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0);
passed &= equality_check(reference_Aux.host_view(), tensor_Aux.host_view());
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledAux) {
abs_max_Aux.sync_host();
passed &= equality_check(reference_abs_max_Aux.host_view(), abs_max_Aux.host_view());
}
@ -990,7 +1074,7 @@ struct Testbed3xFusionOperation {
}
file << "\n\n";
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledD) {
file << "scale_d: " << float(scale_D.at(coord_0));
file << "\nReference abs_max_D :";
file << " " << float(reference_abs_max_D.at(coord_0));
@ -998,15 +1082,16 @@ struct Testbed3xFusionOperation {
file << "\nComputed abs_max_D :";
file << " " << float(abs_max_D.at(coord_0));
file << "\n\n";
if constexpr (IsAuxEnabled) {
file << "scale_aux: " << float(scale_Aux.at(coord_0));
file << "\nReference abs_max_Aux :";
file << " " << float(reference_abs_max_Aux.at(coord_0));
}
file << "\nComputed abs_max_Aux :";
file << " " << float(abs_max_Aux.at(coord_0));
file << "\n\n";
}
if constexpr (IsAbsMaxEnabledAux) {
file << "scale_aux: " << float(scale_Aux.at(coord_0));
file << "\nReference abs_max_Aux :";
file << " " << float(reference_abs_max_Aux.at(coord_0));
file << "\nComputed abs_max_Aux :";
file << " " << float(abs_max_Aux.at(coord_0));
file << "\n\n";
}
file
@ -1018,7 +1103,16 @@ struct Testbed3xFusionOperation {
file << "\n\nBias = \n" << bias.host_view();
}
if constexpr (IsAuxEnabled) {
if constexpr (IsAuxInEnabled) {
file << "\n\nAux Input = \n" << tensor_Aux.host_view();
}
if constexpr (IsDeBiasEnabled) {
file << "\n\nReference dBias = \n" << reference_dbias.host_view();
file << "\n\nComputed dBias = \n" << bias.host_view();
}
if constexpr (IsAuxOutEnabled) {
file
<< "\n\nReference Aux =\n" << reference_Aux.host_view()
<< "\n\nComputed Aux =\n" << tensor_Aux.host_view();
@ -1041,21 +1135,21 @@ struct Testbed3xFusionOperation {
auto L = cute::get<3>(problem_shape_MNKL);
auto coord_0 = cutlass::make_Coord(0);
auto A = cute::make_tensor(impl_.tensor_A.host_data(),
auto A = cute::make_tensor(detail::make_iterator(impl_.tensor_A.host_data()),
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
auto B = cute::make_tensor(impl_.tensor_B.host_data(),
auto B = cute::make_tensor(detail::make_iterator(impl_.tensor_B.host_data()),
cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b));
auto C = cute::make_tensor(impl_.tensor_C.host_data(),
auto C = cute::make_tensor(detail::make_iterator(impl_.tensor_C.host_data()),
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c));
auto D = cute::make_tensor(impl_.reference_D.host_data(),
auto D = cute::make_tensor(detail::make_iterator(impl_.reference_D.host_data()),
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d));
auto Bias = cute::make_tensor(bias.host_data(),
auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()),
cute::make_layout(cute::make_shape(M, cute::_1{})));
auto Aux = cute::make_tensor(reference_Aux.host_data(),
auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()),
cute::make_layout(cute::make_shape(M, N, L), stride_Aux));
auto Valpha = cute::make_tensor(alpha.host_data(),
auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()),
cute::make_layout(cute::make_shape(M, cute::_1{})));
auto Vbeta = cute::make_tensor(beta.host_data(),
auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()),
cute::make_layout(cute::make_shape(M, cute::_1{})));
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
@ -1086,20 +1180,24 @@ struct Testbed3xFusionOperation {
epilogue_params.scale_d = scale_D.at(coord_0);
}
if constexpr (IsBiasEnabled) {
if constexpr (IsBiasEnabled or IsDeBiasEnabled) {
epilogue_params.Bias = Bias;
}
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledD) {
epilogue_params.abs_max_D = reference_abs_max_D.host_data();
}
if constexpr (IsAuxEnabled) {
if constexpr (IsAuxInEnabled) {
epilogue_params.Aux = Aux;
}
if constexpr (IsAuxOutEnabled) {
epilogue_params.Aux = Aux;
if constexpr (IsScaleFactorEnabled) {
epilogue_params.scale_aux = scale_Aux.at(coord_0);
}
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledAux) {
epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data();
}
}
@ -1121,6 +1219,8 @@ struct Testbed3xFusionOperation {
ProblemShapeType problem_size,
ElementScalar alpha_ = ElementScalar(1),
ElementScalar beta_ = ElementScalar(0),
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{},
detail::Splits splits = detail::Splits{},
bool profiling = false,
detail::Iterations iterations = detail::Iterations{})
@ -1136,6 +1236,8 @@ struct Testbed3xFusionOperation {
typename Gemm::Arguments arguments;
cutlass::KernelHardwareInfo hw_info;
cudaDeviceProp prop;
hw_info.device_id = 0;
if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
@ -1146,6 +1248,8 @@ struct Testbed3xFusionOperation {
hw_info.sm_count = impl_.sm_count;
}
cudaGetDeviceProperties(&prop, hw_info.device_id);
/// Initializes data structures
/// A/B/C/D Tensor
initialize(problem_size, alpha_, beta_);
@ -1172,7 +1276,7 @@ struct Testbed3xFusionOperation {
hw_info,
scheduler_args
};
auto coord_0 = cutlass::make_Coord(0);
if constexpr (IsLegacy) {
arguments.epilogue.thread = {
@ -1186,7 +1290,7 @@ struct Testbed3xFusionOperation {
}
else {
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = alpha.at(coord_0);
fusion_args.beta = beta.at(coord_0);
fusion_args.alpha_ptr = alpha.device_data();
@ -1207,6 +1311,10 @@ struct Testbed3xFusionOperation {
fusion_args.bias_ptr = bias.device_data();
}
if constexpr (IsDeBiasEnabled) {
fusion_args.dbias_ptr = bias.device_data();
}
// example of how to set kernel activation arguments
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU_taylor<ElementCompute>>) {
// see ActivationFunctor::Arguments in activation.h for definition
@ -1214,18 +1322,23 @@ struct Testbed3xFusionOperation {
fusion_args.activation.scale = ElementCompute(1);
}
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledD) {
fusion_args.amax_D_ptr = abs_max_D.device_data();
}
if constexpr (IsAuxEnabled) {
if constexpr (IsAuxInEnabled) {
fusion_args.aux_ptr = tensor_Aux.device_data();
fusion_args.dAux = stride_Aux;
}
if constexpr (IsAuxOutEnabled) {
fusion_args.aux_ptr = tensor_Aux.device_data();
fusion_args.dAux = stride_Aux;
if constexpr (IsScaleFactorEnabled) {
fusion_args.scale_aux = scale_Aux.at(coord_0);
fusion_args.scale_aux_ptr = scale_Aux.device_data();
}
if constexpr (IsAbsMaxEnabled) {
if constexpr (IsAbsMaxEnabledAux) {
fusion_args.amax_aux_ptr = abs_max_Aux.device_data();
}
}
@ -1277,6 +1390,7 @@ struct Testbed3xFusionOperation {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
@ -1311,29 +1425,39 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) {
problem_splits.push_back(Stages + 1);
}
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
std::vector<RasterOrderOptions> raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN};
std::vector<int> max_swizzle_sizes = {1, 4};
bool passed = true;
for (int m : problem_size_m) {
for (int n : problem_size_n) {
for (int k : problem_size_k) {
for (int splits : problem_splits) {
ProblemShapeType problem_size;
if constexpr (cute::rank(ProblemShapeType{}) == 4) {
problem_size = ProblemShapeType{m, n, k, /* l */ 1};
}
else {
problem_size = ProblemShapeType{m, n, k};
}
for (auto raster_order : raster_orders) {
for (int max_swizzle_size : max_swizzle_sizes) {
for (int splits : problem_splits) {
ProblemShapeType problem_size;
if constexpr (cute::rank(ProblemShapeType{}) == 4) {
problem_size = ProblemShapeType{m, n, k, /* l */ 1};
}
else {
problem_size = ProblemShapeType{m, n, k};
}
passed = testbed.run(
problem_size,
cutlass::from_real<ElementScalar>(alpha),
cutlass::from_real<ElementScalar>(beta),
detail::Splits(splits)
);
passed = testbed.run(
problem_size,
cutlass::from_real<ElementScalar>(alpha),
cutlass::from_real<ElementScalar>(beta),
raster_order,
detail::MaxSwizzleSize(max_swizzle_size),
detail::Splits(splits)
);
if (!passed) {
return false;
if (!passed) {
return false;
}
}
}
}
}

View File

@ -275,4 +275,4 @@ TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 16x128
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

View File

@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_f16t_s8n_f16t_mixed_input_tensor_op_f16, 128x128x
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

View File

@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

View File

@ -381,4 +381,4 @@ TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x16
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

View File

@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_s8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

View File

@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

View File

@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*! \file
\brief Host reference and operations for Sm90 EVT unit test
\brief Host reference and operations for Sm90 EVT unit test
*/
#pragma once
#include "gemm_testbed_3x_evt.hpp"
@ -53,10 +53,10 @@ public:
using ScalarAlpha = HostScalarBroadcast<Gemm, 1>;
using AccFetchNode = HostAccumulator<Gemm>;
using AuxLoadNode = HostAuxLoad<Gemm, false, ElementAux, LayoutAux>;
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarAlpha, AccFetchNode, AuxLoadNode>;
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarAlpha, AccFetchNode, AuxLoadNode>;
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
using CLoadNode = HostAuxLoad<Gemm, true>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>;
};
@ -67,10 +67,10 @@ public:
using ScalarAlpha = HostScalarBroadcast<Gemm, 1>;
using AccFetchNode = HostAccumulator<Gemm>;
using RowBroadcastNode = HostRowBroadcast<Gemm, ElementBias>;
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarAlpha, AccFetchNode, RowBroadcastNode>;
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarAlpha, AccFetchNode, RowBroadcastNode>;
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
using CLoadNode = HostAuxLoad<Gemm, true>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>;
};
@ -95,13 +95,13 @@ public:
ScalarAlpha,
AccFetchNode,
AuxLoadNode,
HostCompute<Gemm, cutlass::multiply_add>,
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostCompute<Gemm, cutlass::epilogue::thread::ReLu>,
HostCompute<Gemm, cutlass::plus>
>;
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
using CLoadNode = HostAuxLoad<Gemm, true>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, DAGNode>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, DAGNode>;
using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>;
};
@ -114,7 +114,7 @@ public:
using EVTNode = HEVT<
HostAuxStore<Gemm, false, cutlass::half_t, cutlass::layout::RowMajor>,
HEVT<
HostCompute<Gemm, cutlass::multiply_add>,
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 2>,
HostAccumulator<Gemm>,
HostAuxLoad<Gemm, true>
@ -133,7 +133,7 @@ public:
EVTNode,
HostColBroadcast<Gemm, cutlass::half_t>,
HostCompute<Gemm, cutlass::plus>,
HostCompute<Gemm, cutlass::maximum>
HostCompute<Gemm, cutlass::maximum_with_default_nan_propagation>
>
>;
};
@ -147,13 +147,13 @@ public:
using BinaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiplies>, ScalarAlpha, AccFetchNode>;
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
using CLoadNode = HostAuxLoad<Gemm, true>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, BinaryCompute0>;
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, BinaryCompute0>;
using ReduceNode = HEVT<ReduceOp<Gemm, cutlass::plus, float>, TernaryCompute1>;
using EVTModule = HEVT<HostAuxStore<Gemm, true>, ReduceNode>;
};
// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
// if D is fp8
// if D is fp8
// D = scale_d * activation(Z)
// else
// D = activation(Z)
@ -167,11 +167,11 @@ public:
HEVT<
HostCompute<Gemm, ActivationFn>, // activation(Z)
HEVT<
HostCompute<Gemm, cutlass::multiply_add>,
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 2>, // scale_c * beta
HostAuxLoad<Gemm, true>, // C
HEVT<
HostCompute<Gemm, cutlass::multiply_add>,
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>,
@ -184,12 +184,12 @@ public:
};
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
// if D is fp8
// if D is fp8
// amax_d = max(abs(elements in activation(Z)))
// D = scale_d * activation(Z)
// else
// D = activation(Z)
// if Aux is fp8
// if Aux is fp8
// amax_aux = max(abs(elements in Z))
// Aux = scale_aux * Z
// else
@ -204,11 +204,11 @@ public:
HST<Gemm,
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
HEVT<
HostCompute<Gemm, cutlass::multiply_add>,
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 2>, // scale_c * beta
HostAuxLoad<Gemm, true>, // C
HEVT<
HostCompute<Gemm, cutlass::multiply_add>,
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>,
@ -218,7 +218,7 @@ public:
HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>,
HEVT<
HostScalarReduce<Gemm, amax, float>,
HostScalarReduce<Gemm, amax, float>,
HEVT<
HostCompute<Gemm, ActivationFn>, //activation(Z) * scaled_d
HostAccumulator<Gemm>, // Z
@ -247,6 +247,13 @@ public:
namespace cutlass::epilogue {
namespace fusion {
namespace detail {
template <typename T>
struct maximum_with_default_nan_propagation : maximum<T> {};
} // namespace detail
//////////////////////////////////////////////////////////////////////////////
/// D = alpha * acc + beta * C + AuxLoad
template<
@ -258,16 +265,16 @@ template<
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombAuxLoad =
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ScalarBroadcast<ElementScalar>, // alpha
Sm90AccFetch, // acc
Sm90AuxLoad<
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
typename AuxLoadDescriptor::Element,
typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom,
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
typename AuxLoadDescriptor::Element,
typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom,
typename AuxLoadDescriptor::CopyOpS2R // aux load
>
>
@ -286,7 +293,7 @@ template<
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombEVTDAG =
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + aux)
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + aux)
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90TopologicalVisitor<
@ -302,13 +309,13 @@ using Sm90LinCombEVTDAG =
Sm90ScalarBroadcast<ElementScalar>, // alpha
Sm90AccFetch, // acc
Sm90AuxLoad<
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride,
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride,
typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>,
Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>,
Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
Sm90Compute<cutlass::epilogue::thread::ReLu, ElementCompute, ElementCompute, RoundStyle>,
Sm90Compute<plus, ElementCompute, ElementCompute, RoundStyle>
>
>
>;
@ -336,10 +343,10 @@ using Sm90LinCombDAGEVT =
>,
Sm90EVT<
Sm90AuxStore<
AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride,
typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>,
Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>,
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
Sm90ScalarBroadcast<ElementScalar>,
Sm90AccFetch,
Sm90SrcFetch
@ -347,7 +354,7 @@ using Sm90LinCombDAGEVT =
>,
Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias>,
Sm90Compute<plus, ElementCompute, ElementCompute, RoundStyle>,
Sm90Compute<maximum, ElementOutput, ElementCompute, RoundStyle>
Sm90Compute<detail::maximum_with_default_nan_propagation, ElementOutput, ElementCompute, RoundStyle>
>;
@ -362,18 +369,18 @@ template<
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombPerColumnBias =
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ScalarBroadcast<ElementScalar>, // alpha
Sm90AccFetch, // acc
Sm90RowBroadcast<
ceil_div(
EpilogueDescriptor::StagesC,
EpilogueDescriptor::StagesC,
size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{}))
) + 1,
typename EpilogueDescriptor::TileShape,
) + 1,
typename EpilogueDescriptor::TileShape,
ElementBias
>
>
@ -385,7 +392,7 @@ using Sm90LinCombPerColumnBias =
template<
template <class> class RegReduceFn,
template <class> class GmemReduceFn,
class ElementReduce,
class ElementReduce,
class CtaTileShapeMNK,
class ElementOutput,
class ElementCompute,
@ -394,7 +401,7 @@ template<
>
using Sm90LinCombPerColumnReduce =
Sm90EVT<Sm90RowReduction<RegReduceFn, GmemReduceFn, 0, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
@ -410,7 +417,7 @@ using Sm90LinCombPerColumnReduce =
template<
template <class> class RegReduceFn,
template <class> class GmemReduceFn,
class ElementReduce,
class ElementReduce,
class CtaTileShapeMNK,
class ElementOutput,
class ElementCompute,
@ -419,7 +426,7 @@ template<
>
using Sm90LinCombPerRowReduce =
Sm90EVT<Sm90ColReduction<RegReduceFn, GmemReduceFn, 0, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
@ -435,7 +442,7 @@ using Sm90LinCombPerRowReduce =
template<
template <class> class RegReduceFn,
template <class> class GmemReduceFn,
class ElementReduce,
class ElementReduce,
class ElementOutput,
class ElementCompute,
class ElementScalar = ElementCompute,
@ -443,7 +450,7 @@ template<
>
using Sm90LinCombScalarReduce =
Sm90EVT<Sm90ScalarReduction<RegReduceFn, GmemReduceFn, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch, // C
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc

View File

@ -58,44 +58,7 @@ using namespace cute;
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_align8_tensor_op_gmma_f32, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::bfloat16_t, LayoutA, 8,
cutlass::bfloat16_t, LayoutB, 8,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::bfloat16_t, LayoutC, 8,
cutlass::bfloat16_t, LayoutC, 8,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -105,14 +68,14 @@ TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) {
cutlass::bfloat16_t, LayoutA, 4,
cutlass::bfloat16_t, LayoutB, 4,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::bfloat16_t, LayoutC, 4,
@ -132,9 +95,9 @@ TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
@ -142,14 +105,14 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) {
cutlass::bfloat16_t, LayoutA, 2,
cutlass::bfloat16_t, LayoutB, 2,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::bfloat16_t, LayoutC, 2,
@ -169,41 +132,4 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_align8_tensor_op_gmma_f32, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::bfloat16_t, LayoutA, 8,
cutlass::bfloat16_t, LayoutB, 8,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::bfloat16_t, LayoutC, 8,
cutlass::bfloat16_t, LayoutC, 8,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,172 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
  // BF16 TN GEMM (row-major A, column-major B/C), 8-element alignment,
  // cp.async warp-specialized mainloop with a no-smem warp-specialized epilogue.
  using MajorA = cutlass::layout::RowMajor;
  using MajorB = cutlass::layout::ColumnMajor;
  using MajorC = cutlass::layout::ColumnMajor;

  // Mainloop: bf16 inputs, f32 accumulation, 64x128x64 CTA tile, 1x1x1 cluster.
  using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::bfloat16_t, MajorA, 8,
      cutlass::bfloat16_t, MajorB, 8,
      float,
      Shape<_64,_128,_64>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Epilogue: bf16 C and D, f32 compute, same tile/cluster shapes as the mainloop.
  using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_64,_128,_64>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      cutlass::bfloat16_t, MajorC, 8,
      cutlass::bfloat16_t, MajorC, 8,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  // Assemble the kernel over a dynamic (M, N, K, L) problem shape and run the
  // exhaustive device-level testbed.
  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      Mainloop,
      Epilogue
    >;
  using Device = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Device>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
  // BF16 TN GEMM with reduced (4-element) global-memory alignment on A and B,
  // exercising the cp.async warp-specialized schedule on sub-16B aligned inputs.
  using MajorA = cutlass::layout::RowMajor;
  using MajorB = cutlass::layout::ColumnMajor;
  using MajorC = cutlass::layout::ColumnMajor;

  // Mainloop: alignment 4 on both operands, f32 accumulation.
  using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::bfloat16_t, MajorA, 4,
      cutlass::bfloat16_t, MajorB, 4,
      float,
      Shape<_64,_128,_64>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Epilogue: schedule left to the builder (EpilogueScheduleAuto), alignment 4 C/D.
  using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_64,_128,_64>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      cutlass::bfloat16_t, MajorC, 4,
      cutlass::bfloat16_t, MajorC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Compose kernel + device adapter and run the full testbed sweep.
  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      Mainloop,
      Epilogue
    >;
  using Device = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Device>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
  // BF16 TN GEMM at the minimum (2-element) alignment for bf16 operands,
  // covering the smallest sub-16B alignment supported by the warp-specialized path.
  using MajorA = cutlass::layout::RowMajor;
  using MajorB = cutlass::layout::ColumnMajor;
  using MajorC = cutlass::layout::ColumnMajor;

  // Mainloop: alignment 2 on both operands, f32 accumulation.
  using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::bfloat16_t, MajorA, 2,
      cutlass::bfloat16_t, MajorB, 2,
      float,
      Shape<_64,_128,_64>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Epilogue: auto-selected schedule, alignment 2 on C and D tensors.
  using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_64,_128,_64>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      cutlass::bfloat16_t, MajorC, 2,
      cutlass::bfloat16_t, MajorC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Compose kernel + device adapter and run the full testbed sweep.
  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      Mainloop,
      Epilogue
    >;
  using Device = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Device>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,172 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
  // BF16 TN GEMM (row-major A, column-major B and C), alignment 8,
  // cp.async warp-specialized cooperative mainloop, f32 accumulation.
  using ElementAB    = cutlass::bfloat16_t;
  using ElementC     = cutlass::bfloat16_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count is chosen automatically by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 8,
      ElementAB, LayoutB, 8,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Epilogue collective using the NoSmemWarpSpecialized schedule.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 8,
      ElementC, LayoutC, 8,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
  // BF16 TN GEMM (row-major A, column-major B and C), alignment 4,
  // cp.async warp-specialized cooperative mainloop, f32 accumulation.
  using ElementAB    = cutlass::bfloat16_t;
  using ElementC     = cutlass::bfloat16_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count is chosen automatically by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 4,
      ElementAB, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 4,
      ElementC, LayoutC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
  // BF16 TN GEMM (row-major A, column-major B and C), alignment 2,
  // cp.async warp-specialized cooperative mainloop, f32 accumulation.
  using ElementAB    = cutlass::bfloat16_t;
  using ElementC     = cutlass::bfloat16_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count is chosen automatically by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 2,
      ElementAB, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 2,
      ElementC, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,172 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
  // BF16 TN GEMM (row-major A, column-major B and C), alignment 8,
  // cp.async warp-specialized ping-pong mainloop, f32 accumulation.
  using ElementAB    = cutlass::bfloat16_t;
  using ElementC     = cutlass::bfloat16_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_64,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count is chosen automatically by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 8,
      ElementAB, LayoutB, 8,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Epilogue collective using the NoSmemWarpSpecialized schedule.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 8,
      ElementC, LayoutC, 8,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
  // BF16 TN GEMM (row-major A, column-major B and C), alignment 4,
  // cp.async warp-specialized ping-pong mainloop, f32 accumulation.
  using ElementAB    = cutlass::bfloat16_t;
  using ElementC     = cutlass::bfloat16_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_64,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count is chosen automatically by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 4,
      ElementAB, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 4,
      ElementC, LayoutC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
  // BF16 TN GEMM (row-major A, column-major B and C), alignment 2,
  // cp.async warp-specialized ping-pong mainloop, f32 accumulation.
  using ElementAB    = cutlass::bfloat16_t;
  using ElementC     = cutlass::bfloat16_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_64,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count is chosen automatically by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 2,
      ElementAB, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 2,
      ElementC, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,365 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// TT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
  // FP16 TT GEMM (row-major A and B, column-major C), alignment 4,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::RowMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 4,
      ElementAB, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 4,
      ElementC, LayoutC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
  // FP16 TT GEMM (row-major A and B, column-major C), alignment 2,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::RowMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 2,
      ElementAB, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 2,
      ElementC, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// TN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
  // FP16 TN GEMM (row-major A, column-major B and C), alignment 4,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 4,
      ElementAB, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 4,
      ElementC, LayoutC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
  // FP16 TN GEMM (row-major A, column-major B and C), alignment 2,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::RowMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 2,
      ElementAB, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 2,
      ElementC, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// NT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
  // FP16 NT GEMM (column-major A, row-major B, column-major C), alignment 4,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::ColumnMajor;
  using LayoutB      = cutlass::layout::RowMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 4,
      ElementAB, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 4,
      ElementC, LayoutC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
  // FP16 NT GEMM (column-major A, row-major B, column-major C), alignment 2,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::ColumnMajor;
  using LayoutB      = cutlass::layout::RowMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 2,
      ElementAB, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 2,
      ElementC, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// NN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
  // FP16 NN GEMM (column-major A, B, and C), alignment 4,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::ColumnMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 4,
      ElementAB, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 4,
      ElementC, LayoutC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
  // FP16 NN GEMM (column-major A, B, and C), alignment 2,
  // kernel and epilogue schedules selected automatically, f32 accumulation.
  using ElementAB    = cutlass::half_t;
  using ElementC     = cutlass::half_t;
  using LayoutA      = cutlass::layout::ColumnMajor;
  using LayoutB      = cutlass::layout::ColumnMajor;
  using LayoutC      = cutlass::layout::ColumnMajor;
  using TileShape    = Shape<_128,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;

  // Mainloop collective; stage count and schedule chosen by the builder.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementAB, LayoutA, 2,
      ElementAB, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue collective; schedule selected automatically by the builder.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      ElementC, LayoutC, 2,
      ElementC, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime (M, N, K, L) problem shape
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -60,7 +60,7 @@ using namespace cute;
///////////////////////////////////// TT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -72,7 +72,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelMultistage
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -95,7 +95,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -107,7 +107,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -131,7 +131,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
}
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -143,7 +143,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -170,7 +170,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////// TN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -182,7 +182,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelMultistage
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -207,7 +207,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -219,7 +219,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -244,7 +244,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -256,7 +256,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -283,7 +283,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////// NT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -295,7 +295,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelMultistage
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -320,7 +320,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -332,7 +332,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -357,7 +357,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -369,7 +369,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -396,7 +396,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////// NN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -408,7 +408,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelMultistage
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -433,7 +433,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -445,7 +445,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
@ -470,7 +470,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -482,7 +482,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<

View File

@ -0,0 +1,510 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// TT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// f16 x f16 => f16 GEMM, TT layouts (A row-major, B row-major, C column-major),
// f32 accumulation, 8-element A/B alignment, 128x128x64 tile, 1x1x1 cluster,
// cp.async warp-specialized *cooperative* mainloop schedule. The epilogue is
// explicitly pinned to the no-smem warp-specialized schedule.
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// TestAll sweeps problem sizes / alpha-beta combinations for this kernel.
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// Same TT cooperative configuration with 4-element A/B/C alignment; the
// epilogue schedule is left to the builder (EpilogueScheduleAuto).
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// Same TT cooperative configuration with 2-element A/B/C alignment
// (sub-16B-aligned operands), epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// TN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// TN variant (A row-major, B column-major) of the cp.async warp-specialized
// cooperative f16 GEMM, 8-element alignment, 128x128x64 tile, 1x1x1 cluster.
// Epilogue pinned to the no-smem warp-specialized schedule.
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// TN cooperative, 4-element alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// TN cooperative, 2-element (sub-16B) alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// NT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// NT variant (A column-major, B row-major) of the cp.async warp-specialized
// cooperative f16 GEMM, 8-element alignment, 128x128x64 tile, 1x1x1 cluster.
// Epilogue pinned to the no-smem warp-specialized schedule.
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NT cooperative, 4-element alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NT cooperative, 2-element (sub-16B) alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// NN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// NN variant (A and B column-major) of the cp.async warp-specialized
// cooperative f16 GEMM, 8-element alignment, 128x128x64 tile, 1x1x1 cluster.
// Epilogue pinned to the no-smem warp-specialized schedule.
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NN cooperative, 4-element alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NN cooperative, 2-element (sub-16B) alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,510 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// TT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// f16 x f16 => f16 GEMM, TT layouts (A row-major, B row-major, C column-major),
// f32 accumulation, 8-element alignment, 64x128x64 tile, 1x1x1 cluster,
// cp.async warp-specialized *pingpong* mainloop schedule. Epilogue pinned to
// the no-smem warp-specialized schedule.
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// TestAll sweeps problem sizes / alpha-beta combinations for this kernel.
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// TT pingpong, 4-element alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// TT pingpong, 2-element (sub-16B) alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// TN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// TN variant (A row-major, B column-major) of the cp.async warp-specialized
// pingpong f16 GEMM, 8-element alignment, 64x128x64 tile, 1x1x1 cluster.
// Epilogue pinned to the no-smem warp-specialized schedule.
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// TN pingpong, 4-element alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// TN pingpong, 2-element (sub-16B) alignment; epilogue schedule auto-selected.
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// NT //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// NT GEMM (ColumnMajor A x RowMajor B -> ColumnMajor C/D): f16 in / f16 out
// with f32 accumulation, full 8-element (16-byte) alignment, cp.async
// warp-specialized ping-pong mainloop schedule on SM90.
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
// NOTE(review): the align8 variants in this file pin the epilogue to
// NoSmemWarpSpecialized while the align4/align2 variants use
// EpilogueScheduleAuto — presumably intentional coverage of the
// direct-store epilogue path; confirm against the builder's dispatch rules.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NT GEMM (ColumnMajor A x RowMajor B -> ColumnMajor C/D): f16 in / f16 out
// with f32 accumulation. Exercises 4-element (8-byte) alignment with the
// cp.async warp-specialized ping-pong mainloop schedule on SM90.
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
// Mainloop: 64x128x64 CTA tile, 1x1x1 cluster, A/B alignment of 4 elements.
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
// Epilogue: f32 accumulate/compute, f16 C and D at matching alignment 4;
// schedule auto-selected by the builder.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NT GEMM (ColumnMajor A x RowMajor B -> ColumnMajor C/D): f16 in / f16 out
// with f32 accumulation. Exercises the minimum 2-element (4-byte) alignment
// with the cp.async warp-specialized ping-pong mainloop schedule on SM90.
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
// Mainloop: 64x128x64 CTA tile, 1x1x1 cluster, A/B alignment of 2 elements.
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
// Epilogue: f32 accumulate/compute, f16 C and D at matching alignment 2;
// schedule auto-selected by the builder.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////// NN //////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// NN GEMM (ColumnMajor A x ColumnMajor B -> ColumnMajor C/D): f16 in /
// f16 out with f32 accumulation, full 8-element (16-byte) alignment,
// cp.async warp-specialized ping-pong mainloop schedule on SM90.
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
// NOTE(review): as with the other align8 variants in this file, the epilogue
// is pinned to NoSmemWarpSpecialized rather than EpilogueScheduleAuto —
// presumably deliberate coverage of the direct-store path; confirm.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NN GEMM (ColumnMajor A x ColumnMajor B -> ColumnMajor C/D): f16 in /
// f16 out with f32 accumulation. Exercises 4-element (8-byte) alignment
// with the cp.async warp-specialized ping-pong mainloop schedule on SM90.
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
// Mainloop: 64x128x64 CTA tile, 1x1x1 cluster, A/B alignment of 4 elements.
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 4,
cutlass::half_t, LayoutB, 4,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
// Epilogue: f32 accumulate/compute, f16 C and D at matching alignment 4;
// schedule auto-selected by the builder.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 4,
cutlass::half_t, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
// NN GEMM (ColumnMajor A x ColumnMajor B -> ColumnMajor C/D): f16 in /
// f16 out with f32 accumulation. Exercises the minimum 2-element (4-byte)
// alignment with the cp.async warp-specialized ping-pong schedule on SM90.
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
// Mainloop: 64x128x64 CTA tile, 1x1x1 cluster, A/B alignment of 2 elements.
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 2,
cutlass::half_t, LayoutB, 2,
float,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
>::CollectiveOp;
// Epilogue: f32 accumulate/compute, f16 C and D at matching alignment 2;
// schedule auto-selected by the builder.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 2,
cutlass::half_t, LayoutC, 2,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -469,7 +469,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8_ReLU_VoidC) {
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8_ReLU_VoidC_U1Aux) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@ -478,8 +478,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
using ClusterShape_MNK = Shape<_2,_2,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
// ReLU with uint1b_t aux will compute dReLU/dZ as the aux output, i.e. Aux(i) = (Z(i) >= 0) ? 1 : 0
using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux<
LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, int8_t>;
LayoutC, cutlass::epilogue::thread::ReLU, cutlass::half_t, float, cutlass::uint1b_t, int8_t>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
@ -514,4 +515,94 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}
// EVT fusion test: LinCombDeEltActDePerRowBias with dReLU on the cooperative
// TMA warp-specialized epilogue. The C operand is void (not read by the
// epilogue), and a cutlass::uint1b_t type appears in the fusion parameter
// list — presumably the auxiliary dReLU mask input; confirm against the
// LinCombDeEltActDePerRowBias template's parameter order.
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_dReLU_dBias_VoidC) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_256,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias<
LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
// Mainloop stage count is carved out around the epilogue's shared storage
// so the two collectives fit in SM90 shared memory together.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
EXPECT_TRUE(passed);
}
// EVT fusion test: LinCombDeEltAct with dGELU on the cooperative TMA
// warp-specialized epilogue; C is void (not read). The testbed is run with
// alpha=1, beta=0 and relative-equality checking — dGELU is nonlinear, so
// exact bitwise agreement with the reference is not expected.
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_dGELU_VoidC) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_256,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltAct<
LayoutC, cutlass::epilogue::thread::dGELU, cutlass::half_t, float, cutlass::half_t>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
// Mainloop stage count is carved out around the epilogue's shared storage.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>(1.0, 0.0, /*check_relative_equality=*/true);
EXPECT_TRUE(passed);
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -460,4 +460,49 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}
// Ping-pong counterpart of the cooperative dReLU/dBias fusion test: same
// LinCombDeEltActDePerRowBias fusion (dReLU with a cutlass::uint1b_t type
// in the parameter list — presumably the aux mask input; confirm), but with
// a 128x128x64 tile, the TmaWarpSpecialized epilogue, and the
// KernelTmaWarpSpecializedPingpong mainloop schedule. C is void (not read).
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_pingpong_epilogue, 128x128x64_2x2x1_dReLU_dBias_VoidC) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias<
LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
// Mainloop stage count is carved out around the epilogue's shared storage.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
EXPECT_TRUE(passed);
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,209 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
// Manually composed SM90 mainloop (no gemm CollectiveBuilder) paired with a
// builder-generated epilogue. Uses GMMA::rs_op_selector and the
// MainloopSm90TmaGmmaRmemAWarpSpecialized dispatch policy — i.e. the A
// operand is sourced from registers ("RmemA") while being loaded from gmem
// via TMA. Epilogue fuses scaled linear-combination + per-row bias +
// elementwise activation (Identity).
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x192x64_1x1x1) {
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = float;
using LayoutC = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using TileShape_MNK = Shape<_128,_192,_64>;
using ClusterShape_MNK = Shape<_1,_1,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),  // alignment in elements for a 16B access
ElementC, LayoutC, 16 / sizeof(ElementC),
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
// Tile the GMMA atom 2x along M (cooperative schedule uses two consumer warpgroups).
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
// TMA atom selection depends on the cluster extent along the *other* mode
// (multicast when >1): A keys off cluster N, B off cluster M.
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
// Stage count fills whatever smem remains after the epilogue's carveout.
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementB,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,  // LDSM copy: A smem -> registers
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
void,  // no smem->rmem copy atom for B — presumably consumed from smem by the MMA
cute::identity
>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
EpilogueOp
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
}
// Same manually composed RmemA warp-specialized configuration as the
// 128x192x64_1x1x1 variant, but with a 2x1x1 cluster — since cluster M is 2,
// GmemTiledCopyB resolves to the multicast TMA atom while A stays unicast.
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x192x64_2x1x1) {
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = float;
using LayoutC = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using TileShape_MNK = Shape<_128,_192,_64>;
using ClusterShape_MNK = Shape<_2,_1,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),  // alignment in elements for a 16B access
ElementC, LayoutC, 16 / sizeof(ElementC),
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
// Tile the GMMA atom 2x along M (cooperative schedule uses two consumer warpgroups).
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
// Stage count fills whatever smem remains after the epilogue's carveout.
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementB,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,  // LDSM copy: A smem -> registers
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
void,  // no smem->rmem copy atom for B — presumably consumed from smem by the MMA
cute::identity
>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
EpilogueOp
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,209 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
// FP8 (e4m3 x e4m3 -> f32) variant of the manually composed RmemA
// warp-specialized mainloop: GMMA::rs_op_selector with a
// MainloopSm90TmaGmmaRmemAWarpSpecialized dispatch policy (A operand sourced
// from registers, loaded via TMA), plus a builder-generated epilogue fusing
// scaled linear-combination + per-row bias + Identity activation.
TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_1x1x1) {
using ElementA = cutlass::float_e4m3_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::float_e4m3_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = float;
using LayoutC = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using TileShape_MNK = Shape<_128,_128,_128>;
using ClusterShape_MNK = Shape<_1,_1,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),  // alignment in elements for a 16B access
ElementC, LayoutC, 16 / sizeof(ElementC),
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
// Tile the GMMA atom 2x along M (cooperative schedule uses two consumer warpgroups).
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
// TMA atom selection keys off the cluster extent along the opposite mode.
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
// Stage count fills whatever smem remains after the epilogue's carveout.
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementB,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,  // LDSM copy: A smem -> registers
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
void,  // no smem->rmem copy atom for B — presumably consumed from smem by the MMA
cute::identity
>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
EpilogueOp
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
}
// Same FP8 RmemA warp-specialized configuration as the 128x128x128_1x1x1
// variant, but with a 2x1x1 cluster — cluster M of 2 switches the B-operand
// TMA atom to its multicast form while A stays unicast.
TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_2x1x1) {
using ElementA = cutlass::float_e4m3_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::float_e4m3_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = float;
using LayoutC = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using TileShape_MNK = Shape<_128,_128,_128>;
using ClusterShape_MNK = Shape<_2,_1,_1>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),  // alignment in elements for a 16B access
ElementC, LayoutC, 16 / sizeof(ElementC),
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
// Tile the GMMA atom 2x along M (cooperative schedule uses two consumer warpgroups).
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
// Stage count fills whatever smem remains after the epilogue's carveout.
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementB,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,  // LDSM copy: A smem -> registers
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
void,  // no smem->rmem copy atom for B — presumably consumed from smem by the MMA
cute::identity
>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,  // dynamic problem shape (M, N, K, L)
CollectiveOp,
EpilogueOp
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -58,7 +58,7 @@ using namespace cute;
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) {
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 128x128x128) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -68,14 +68,14 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) {
int8_t, LayoutA, 8,
int8_t, LayoutB, 8,
int32_t,
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
int32_t, int32_t,
int8_t, LayoutC, 8,
@ -93,42 +93,7 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) {
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// SM90 int8 GEMM (A row-major, B column-major, C column-major) with a
// multistage mainloop schedule, 128x128x128 tile and 1x1x1 cluster.
// Runs the device-side correctness sweep from gemm_testbed_3x.hpp.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) {
// Operand layout tags for A, B and C.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
// Mainloop: int8 A/B at 16-element alignment, int32 accumulator,
// stage count chosen automatically, multistage kernel schedule.
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
int8_t, LayoutA, 16,
int8_t, LayoutB, 16,
int32_t,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelMultistage
>::CollectiveOp;
// Epilogue: int32 accumulate/compute, int8 C/D output, no-smem schedule.
// NOTE(review): C/D alignment here is 8 while the mainloop uses 16 —
// presumably intentional for this configuration, but worth confirming.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
int32_t, int32_t,
int8_t, LayoutC, 8,
int8_t, LayoutC, 8,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
// Assemble the kernel with a fully dynamic (M, N, K, L) problem shape.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) {
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x128x128) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -138,14 +103,14 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) {
int8_t, LayoutA, 4,
int8_t, LayoutB, 4,
int32_t,
Shape<_128,_64,_128>, Shape<_1,_1,_1>,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_64,_128>, Shape<_1,_1,_1>,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
int32_t, int32_t,
int8_t, LayoutC, 4,

View File

@ -0,0 +1,168 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
// SM90 int8 GEMM, cp.async warp-specialized mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 16-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 16,
  // no-smem warp-specialized epilogue schedule.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 16,
      int8_t, LayoutTagC, 16,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 16, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 16,
      int8_t, LayoutTagB, 16,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 int8 GEMM, cp.async warp-specialized mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 8-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 8,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 8,
      int8_t, LayoutTagC, 8,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 8, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 8,
      int8_t, LayoutTagB, 8,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 int8 GEMM, cp.async warp-specialized mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 4-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 4,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 4,
      int8_t, LayoutTagC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 4, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 4,
      int8_t, LayoutTagB, 4,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,168 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
// SM90 int8 GEMM, cp.async warp-specialized cooperative mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 16-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 16,
  // no-smem warp-specialized epilogue schedule.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 16,
      int8_t, LayoutTagC, 16,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 16, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 16,
      int8_t, LayoutTagB, 16,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 int8 GEMM, cp.async warp-specialized cooperative mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 8-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 8,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 8,
      int8_t, LayoutTagC, 8,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 8, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 8,
      int8_t, LayoutTagB, 8,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 int8 GEMM, cp.async warp-specialized cooperative mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 4-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 4,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 4,
      int8_t, LayoutTagC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 4, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 4,
      int8_t, LayoutTagB, 4,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,168 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
// SM90 int8 GEMM, cp.async warp-specialized ping-pong mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 16-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 16,
  // no-smem warp-specialized epilogue schedule.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 16,
      int8_t, LayoutTagC, 16,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 16, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 16,
      int8_t, LayoutTagB, 16,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 int8 GEMM, cp.async warp-specialized ping-pong mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 8-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 8,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 8,
      int8_t, LayoutTagC, 8,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 8, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 8,
      int8_t, LayoutTagB, 8,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 int8 GEMM, cp.async warp-specialized ping-pong mainloop schedule,
// 128x128x128 tile, 1x1x1 cluster, 4-element operand alignment.
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: int32 accumulate/compute, int8 C/D at alignment 4,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      int32_t, int32_t,
      int8_t, LayoutTagC, 4,
      int8_t, LayoutTagC, 4,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: int8 A/B at alignment 4, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      int8_t, LayoutTagA, 4,
      int8_t, LayoutTagB, 4,
      int32_t,
      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -62,10 +62,10 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape
Scheduler scheduler{params};
auto work_tile_info = scheduler.get_current_work();
while (work_tile_info.is_valid_tile) {
while (work_tile_info.is_valid()) {
// Increment counters to indicate coverage
auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx;
auto offset = tile_idx * params.divmod_tiles_per_output_tile_.divisor + work_tile_info.K_idx;
for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
// Use atomicAdd because the visit counters are shared by multiple thread blocks.
// While having more than one block increment the same counter indicates failure,
@ -108,7 +108,7 @@ test_scheduler(
// Allocate counters indicating the number of times each k iteration of each output tile has been visited
auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_;
auto total_counters = blk_m * blk_n * blk_l * params.divmod_tiles_per_output_tile_.divisor;
cutlass::DeviceAllocation<int> visit_counters(total_counters);
// Initialize counters to zero
@ -181,8 +181,6 @@ test_scheduler(
for (size_t i = 0; i < host_visit_counts.size(); ++i) {
if (host_visit_counts[i] != 1) {
// for (int count : host_visit_counts) {
// if (count != 1) {
std::cout << "Failed with problem size "
<< size<0>(problem_shape_mnkl) << "x"
<< size<1>(problem_shape_mnkl) << "x"
@ -191,11 +189,12 @@ test_scheduler(
<< " and grid size " << grid.x << "x"
<< grid.y << "x" << grid.z
<< " splits=" << params.splits_
<< " k_iter=" << params.k_tiles_per_output_tile_
<< " k_iter=" << params.divmod_tiles_per_output_tile_.divisor
<< " big_units=" << params.big_units_
<< " sk_tiles=" << params.sk_tiles_
<< " sk_units=" << params.sk_units_
<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl;
<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_
<< " units_per_problem=" << params.units_per_problem_ << std::endl;
std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl;
return false;
}

View File

@ -57,42 +57,7 @@ using namespace cute;
///////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32, 64x128x32) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
tfloat32_t, LayoutA, 4,
tfloat32_t, LayoutB, 4,
float,
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelMultistage
>::CollectiveOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
float, LayoutC, 4,
float, LayoutC, 4,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) {
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 128x64x32) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@ -102,7 +67,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) {
cutlass::tfloat32_t, LayoutA, 2,
cutlass::tfloat32_t, LayoutB, 2,
float,
Shape<_64,_64,_32>, Shape<_1,_1,_1>,
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

View File

@ -0,0 +1,167 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
// SM90 tf32 GEMM (f32 accumulate), cp.async warp-specialized mainloop
// schedule, 128x64x32 tile, 1x1x1 cluster, 4-element operand alignment.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized, 128x64x32) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: f32 accumulate/compute, f32 C/D at alignment 4,
  // no-smem warp-specialized epilogue schedule.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_64,_32>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutTagC, 4,
      float, LayoutTagC, 4,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  // Mainloop: tf32 A/B at alignment 4, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      tfloat32_t, LayoutTagA, 4,
      tfloat32_t, LayoutTagB, 4,
      float,
      Shape<_128,_64,_32>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 tf32 GEMM (f32 accumulate), cp.async warp-specialized mainloop
// schedule, 128x64x32 tile, 1x1x1 cluster, 2-element operand alignment.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized, 128x64x32) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: f32 accumulate/compute, f32 C/D at alignment 2,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_64,_32>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutTagC, 2,
      float, LayoutTagC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: tf32 A/B at alignment 2, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutTagA, 2,
      cutlass::tfloat32_t, LayoutTagB, 2,
      float,
      Shape<_128,_64,_32>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
// SM90 tf32 GEMM (f32 accumulate), cp.async warp-specialized mainloop
// schedule, 128x64x32 tile, 1x1x1 cluster, 1-element operand alignment.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized, 128x64x32) {
  // Layout tags: A row-major, B column-major, C/D column-major.
  using LayoutTagA = cutlass::layout::RowMajor;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  using LayoutTagC = cutlass::layout::ColumnMajor;

  // Epilogue: f32 accumulate/compute, f32 C/D at alignment 1,
  // epilogue schedule picked automatically.
  using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_64,_32>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutTagC, 1,
      float, LayoutTagC, 1,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  // Mainloop: tf32 A/B at alignment 1, automatic stage count.
  using MainloopOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutTagA, 1,
      cutlass::tfloat32_t, LayoutTagB, 1,
      float,
      Shape<_128,_64,_32>, Shape<_1,_1,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecialized
    >::CollectiveOp;

  // Dynamic (M, N, K, L) problem shape.
  using UnderlyingKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      MainloopOp,
      EpilogueOp
    >;
  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<UnderlyingKernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<DeviceGemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,167 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
// TF32 GEMM, A row-major x B column-major -> f32 column-major, 4-element
// alignment, cp.async warp-specialized *cooperative* kernel, 128x64x32 tile.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) {
  // Tile and cluster shapes are shared by the mainloop and epilogue builders.
  using TileShape    = Shape<_128,_64,_32>;
  using ClusterShape = Shape<_1,_1,_1>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Mainloop: tf32 inputs, f32 accumulation, stage count chosen automatically.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutA, 4,
      cutlass::tfloat32_t, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Epilogue: no-smem warp-specialized schedule; C and D share layout/alignment.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutC, 4,
      float, LayoutC, 4,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime problem shape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// TF32 GEMM, A row-major x B column-major -> f32 column-major, 2-element
// alignment, cp.async warp-specialized *cooperative* kernel, 128x64x32 tile.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) {
  // Tile and cluster shapes are shared by the mainloop and epilogue builders.
  using TileShape    = Shape<_128,_64,_32>;
  using ClusterShape = Shape<_1,_1,_1>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Mainloop: tf32 inputs, f32 accumulation, stage count chosen automatically.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutA, 2,
      cutlass::tfloat32_t, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Epilogue: f32 compute/accumulate; C and D share layout and alignment 2.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutC, 2,
      float, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime problem shape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// TF32 GEMM, A row-major x B column-major -> f32 column-major, 1-element
// alignment, cp.async warp-specialized *cooperative* kernel, 128x64x32 tile.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) {
  // Tile and cluster shapes are shared by the mainloop and epilogue builders.
  using TileShape    = Shape<_128,_64,_32>;
  using ClusterShape = Shape<_1,_1,_1>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Mainloop: tf32 inputs, f32 accumulation, stage count chosen automatically.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutA, 1,
      cutlass::tfloat32_t, LayoutB, 1,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
    >::CollectiveOp;

  // Epilogue: f32 compute/accumulate; C and D share layout and alignment 1.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutC, 1,
      float, LayoutC, 1,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime problem shape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -0,0 +1,167 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x.hpp"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
///////////////////////////////////////////////////////////////////////////////
// TF32 GEMM, A row-major x B column-major -> f32 column-major, 4-element
// alignment, cp.async warp-specialized *pingpong* kernel, 128x64x32 tile.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) {
  // Tile and cluster shapes are shared by the mainloop and epilogue builders.
  using TileShape    = Shape<_128,_64,_32>;
  using ClusterShape = Shape<_1,_1,_1>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Mainloop: tf32 inputs, f32 accumulation, stage count chosen automatically.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutA, 4,
      cutlass::tfloat32_t, LayoutB, 4,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Epilogue: no-smem warp-specialized schedule; C and D share layout/alignment.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutC, 4,
      float, LayoutC, 4,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime problem shape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// TF32 GEMM, A row-major x B column-major -> f32 column-major, 2-element
// alignment, cp.async warp-specialized *pingpong* kernel, 128x64x32 tile.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) {
  // Tile and cluster shapes are shared by the mainloop and epilogue builders.
  using TileShape    = Shape<_128,_64,_32>;
  using ClusterShape = Shape<_1,_1,_1>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Mainloop: tf32 inputs, f32 accumulation, stage count chosen automatically.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutA, 2,
      cutlass::tfloat32_t, LayoutB, 2,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Epilogue: f32 compute/accumulate; C and D share layout and alignment 2.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutC, 2,
      float, LayoutC, 2,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime problem shape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
// TF32 GEMM, A row-major x B column-major -> f32 column-major, 1-element
// alignment, cp.async warp-specialized *pingpong* kernel, 128x64x32 tile.
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) {
  // Tile and cluster shapes are shared by the mainloop and epilogue builders.
  using TileShape    = Shape<_128,_64,_32>;
  using ClusterShape = Shape<_1,_1,_1>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Mainloop: tf32 inputs, f32 accumulation, stage count chosen automatically.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::tfloat32_t, LayoutA, 1,
      cutlass::tfloat32_t, LayoutB, 1,
      float,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
    >::CollectiveOp;

  // Epilogue: f32 compute/accumulate; C and D share layout and alignment 1.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      float, LayoutC, 1,
      float, LayoutC, 1,
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using Kernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,   // runtime problem shape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@ -99,7 +99,7 @@ TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) {
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::tfloat32_t, LayoutA, 1,
cutlass::tfloat32_t, LayoutA, 4,
cutlass::tfloat32_t, LayoutB, 4,
float,
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
@ -136,8 +136,8 @@ TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::tfloat32_t, LayoutA, 1,
cutlass::tfloat32_t, LayoutB, 1,
cutlass::tfloat32_t, LayoutA, 4,
cutlass::tfloat32_t, LayoutB, 4,
float,
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
@ -149,8 +149,8 @@ TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
float, LayoutC, 1,
float, LayoutC, 1,
float, LayoutC, 4,
float, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
@ -174,7 +174,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::tfloat32_t, LayoutA, 4,
cutlass::tfloat32_t, LayoutB, 1,
cutlass::tfloat32_t, LayoutB, 4,
float,
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAuto,
@ -188,7 +188,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
float, float,
float, LayoutC, 4,
float, LayoutC, 4,
cutlass::epilogue::collective::EpilogueScheduleAuto
cutlass::gemm::EpilogueTransposed
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<

View File

@ -337,6 +337,8 @@ TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_un_tensor_op_f32_align1_align4, 128x
/////////////////////////////////////////////////////////////////////////////////////////////////
// This test fails on Ada when running with 11.8
#if ((__CUDACC_VER_MAJOR__ != 11) || (__CUDACC_VER_MINOR__ != 8) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890)))
TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) {
using ElementOutput = float;
@ -374,6 +376,7 @@ TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());
}
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -55,6 +55,7 @@
////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -98,6 +99,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16
////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -118,7 +120,6 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_
.run();
}
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -142,6 +143,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16
////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -185,6 +187,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_
////////////////////////////////////////////////////////////////////////////////
/// F32 <= U8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -225,10 +228,10 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= BF16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -252,6 +255,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1
////////////////////////////////////////////////////////////////////////////////
/// F32 <= U8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -273,8 +277,9 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_1
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= B16 * I8 + F32 (Upcast on Operand B)
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -296,8 +301,9 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_1
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
/// F32 <= BF16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -318,4 +324,4 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1
.run();
}
#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -44,7 +44,7 @@ void rmsnorm_host(cutlass::MatrixCoord tensor_size,
cutlass::TensorRef<ElementType, Layout> output,
cutlass::TensorRef<ElementType, Layout> input,
cutlass::TensorRef<ElementType, Layout> weight,
float epsilon) {
float epsilon) {
const int M = tensor_size.row();
const int N = tensor_size.column();
@ -94,7 +94,7 @@ void run_test(int M, int N) {
rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5);
cutlass::rmsnorm({M, N}, output.device_ref(),
input.device_ref(), weight.device_ref(), NULL, (float)1e-5);
input.device_ref(), weight.device_ref(), NULL, (float)1e-5L);
output.sync_host();