CUTLASS 3.3.0 (#1167)
* Release 3.3.0 Adds support for mixed precision GEMMs On Hopper and Ampere Adds support for < 16B aligned GEMMs on Hopper Enhancements to EVT Enhancements to Python interface Enhancements to Sub-byte type handling in CuTe Several other bug-fixes and performance improvements. * minor doc update
This commit is contained in:
@ -28,4 +28,8 @@
|
||||
|
||||
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
|
||||
add_subdirectory(unit)
|
||||
else()
|
||||
# Always provide at least the phony test_unit target.
|
||||
add_custom_target(test_unit)
|
||||
endif()
|
||||
|
||||
|
||||
@ -36,8 +36,9 @@ Utilities for defining Conv2D problem sizes for testing.
|
||||
This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h
|
||||
"""
|
||||
|
||||
from cutlass_library import ConvMode
|
||||
|
||||
import cutlass
|
||||
from cutlass import ConvMode
|
||||
from cutlass.shape import Conv2DProblemSize
|
||||
|
||||
|
||||
|
||||
@ -34,10 +34,11 @@
|
||||
Utility functions for Conv2d tests.
|
||||
"""
|
||||
|
||||
from cutlass_library import SubstituteTemplate
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
from cutlass import (
|
||||
from cutlass_library import (
|
||||
ConvKind,
|
||||
ConvMode,
|
||||
DataType,
|
||||
@ -50,7 +51,6 @@ from cutlass import (
|
||||
ShortLayoutTypeNames,
|
||||
SplitKMode,
|
||||
)
|
||||
from cutlass.backend.utils.software import SubstituteTemplate
|
||||
from cutlass.shape import Conv2DProblemSize
|
||||
from cutlass.utils.datatypes import numpy_type, torch_type
|
||||
|
||||
@ -301,17 +301,19 @@ class Conv2dLauncherFrontend:
|
||||
tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B)
|
||||
tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C)
|
||||
tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last)
|
||||
self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
|
||||
args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
|
||||
stride=(ps.stride_h, ps.stride_w),
|
||||
padding=(ps.pad_h, ps.pad_w),
|
||||
dilation=(ps.dilation_h, ps.dilation_w),
|
||||
alpha=alpha, beta=beta,
|
||||
split_k=(split_k_mode, split_k_slices))
|
||||
|
||||
args.sync()
|
||||
|
||||
tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
passed = torch.equal(tensor_D, tensor_D_ref)
|
||||
passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06)
|
||||
|
||||
return passed
|
||||
|
||||
@ -378,7 +380,8 @@ def add_test(
|
||||
conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch")
|
||||
|
||||
for ps in problem_sizes:
|
||||
if not validate_problem_size(ps, conv_kind, split_k_slices): continue
|
||||
if not validate_problem_size(ps, conv_kind, split_k_slices):
|
||||
continue
|
||||
|
||||
self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0))
|
||||
|
||||
|
||||
@ -38,9 +38,11 @@ import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from cutlass_library import ConvMode
|
||||
|
||||
import cutlass
|
||||
|
||||
if cutlass.utils.datatypes.torch_available:
|
||||
if cutlass.utils.datatypes.is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@ -88,7 +90,7 @@ def _generate_problems(dtype, num):
|
||||
def _generate_conv2d_problem(conv_kind, dtype, ps):
|
||||
"""
|
||||
Utility function to generate conv2d inputs
|
||||
|
||||
|
||||
:param conv_kind: kind of convolution
|
||||
:type conv_kind: str
|
||||
:param dtype: data type of tensors
|
||||
@ -114,7 +116,7 @@ def _generate_conv2d_problem(conv_kind, dtype, ps):
|
||||
return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes]
|
||||
|
||||
|
||||
@unittest.skipIf(not cutlass.utils.datatypes.torch_available, 'PyTorch must be available to run PyTorch extension tests')
|
||||
@unittest.skipIf(not cutlass.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests')
|
||||
class PyTorchExtensionTest(unittest.TestCase):
|
||||
|
||||
def test_gemm(self):
|
||||
@ -183,18 +185,18 @@ class PyTorchExtensionTest(unittest.TestCase):
|
||||
Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)]
|
||||
Ds = mod.run(As, Bs, Cs, alpha, beta)
|
||||
check_all(Ds, Ds_ref)
|
||||
|
||||
|
||||
def test_conv2d_fprop(self):
|
||||
torch.manual_seed(2023)
|
||||
|
||||
|
||||
dtype = torch.float16
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32)
|
||||
plan.activation = "relu"
|
||||
|
||||
|
||||
op = plan.construct()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
mod = cutlass.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
|
||||
|
||||
|
||||
problem_size = cutlass.shape.Conv2DProblemSize(
|
||||
1, 4, 4, 16,
|
||||
8, 3, 3, 16,
|
||||
@ -202,50 +204,50 @@ class PyTorchExtensionTest(unittest.TestCase):
|
||||
3, 3,
|
||||
1, 1
|
||||
)
|
||||
|
||||
|
||||
A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size)
|
||||
stride = (problem_size.stride_h, problem_size.stride_w)
|
||||
padding = (problem_size.pad_h, problem_size.pad_w)
|
||||
|
||||
alpha = 1.0
|
||||
beta = 0.5
|
||||
|
||||
|
||||
D_ref = alpha * torch.ops.aten.conv2d(
|
||||
A, B, stride=stride, padding=padding
|
||||
) + beta * C
|
||||
D_ref = torch.nn.functional.relu(D_ref)
|
||||
D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta)
|
||||
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
# Test serial split-K
|
||||
D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
|
||||
assert torch.allclose(D, D_serial_split_k)
|
||||
|
||||
|
||||
# Test parallel split-K
|
||||
D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
|
||||
assert torch.allclose(D, D_parallel_split_k)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_conv2d_dgrad(self):
|
||||
torch.manual_seed(2023)
|
||||
dtype = torch.float16
|
||||
plan = cutlass.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32)
|
||||
|
||||
|
||||
op = plan.construct()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
mod = cutlass.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
|
||||
|
||||
|
||||
problem_size = cutlass.shape.Conv2DProblemSize(
|
||||
1, 4, 4, 16,
|
||||
8, 3, 3, 16,
|
||||
0, 0,
|
||||
3, 3,
|
||||
1, 1,
|
||||
cutlass.ConvMode.CrossCorrelation,
|
||||
ConvMode.CrossCorrelation,
|
||||
1, 1
|
||||
)
|
||||
|
||||
|
||||
A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size)
|
||||
stride = (problem_size.stride_h, problem_size.stride_w)
|
||||
padding = (problem_size.pad_h, problem_size.pad_w)
|
||||
@ -254,32 +256,32 @@ class PyTorchExtensionTest(unittest.TestCase):
|
||||
beta = 0.5
|
||||
input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W)
|
||||
D_ref = alpha * torch.nn.grad.conv2d_input(
|
||||
input_size, B, A,
|
||||
input_size, B, A,
|
||||
stride=stride, padding=padding
|
||||
) + beta * C
|
||||
D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, )
|
||||
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
def test_conv2d_wgrad(self):
|
||||
torch.manual_seed(2023)
|
||||
dtype = torch.float16
|
||||
plan = cutlass.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32)
|
||||
|
||||
|
||||
op = plan.construct()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
mod = cutlass.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
|
||||
|
||||
|
||||
problem_size = cutlass.shape.Conv2DProblemSize(
|
||||
1, 4, 4, 16,
|
||||
8, 3, 3, 16,
|
||||
0, 0,
|
||||
3, 3,
|
||||
1, 1,
|
||||
cutlass.ConvMode.CrossCorrelation,
|
||||
ConvMode.CrossCorrelation,
|
||||
1, 1
|
||||
)
|
||||
|
||||
|
||||
A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size)
|
||||
stride = (problem_size.stride_h, problem_size.stride_w)
|
||||
padding = (problem_size.pad_h, problem_size.pad_w)
|
||||
@ -288,17 +290,17 @@ class PyTorchExtensionTest(unittest.TestCase):
|
||||
beta = 0.5
|
||||
weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S)
|
||||
D_ref = alpha * torch.nn.grad.conv2d_weight(
|
||||
B, weight_size, A,
|
||||
B, weight_size, A,
|
||||
stride=stride, padding=padding
|
||||
) + beta * C
|
||||
D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta)
|
||||
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
# Test serial split-K
|
||||
D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
|
||||
assert torch.allclose(D, D_serial_split_k)
|
||||
|
||||
|
||||
# Test parallel split-K
|
||||
D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
|
||||
assert torch.allclose(D, D_parallel_split_k)
|
||||
|
||||
@ -40,9 +40,9 @@ import unittest
|
||||
import cutlass
|
||||
from cutlass import Tensor
|
||||
import cutlass.backend.evt
|
||||
from cutlass.profiler import CUDAEventProfiler
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass.utils.datatypes import torch_type
|
||||
from cutlass.utils.profiler import CUDAEventProfiler
|
||||
|
||||
|
||||
class EVTReferenceModule:
|
||||
|
||||
@ -43,7 +43,7 @@ import cutlass
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
import torch
|
||||
|
||||
from utils import LayoutCombination, add_test_gemm
|
||||
from utils import LayoutCombination
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
|
||||
|
||||
@ -67,58 +67,58 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f16, cc=c
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16,
|
||||
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
|
||||
# Tests using SIMT
|
||||
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
|
||||
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
|
||||
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
|
||||
# Stream K tests
|
||||
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
|
||||
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -135,6 +135,10 @@ add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass.DataType.f16
|
||||
add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8])
|
||||
add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8])
|
||||
|
||||
# Tests with void-C kernels
|
||||
add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None,
|
||||
cluster_shape=[2, 1, 1], element_C=cutlass.DataType.void)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@ -68,31 +68,31 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f32, cc=c
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
|
||||
# Tests using SIMT
|
||||
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
|
||||
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
|
||||
# Stream K tests
|
||||
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
|
||||
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
|
||||
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
|
||||
|
||||
|
||||
@ -68,30 +68,30 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f64, cc=c
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
|
||||
|
||||
# Tests using SIMT
|
||||
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
|
||||
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
|
||||
# Stream K tests
|
||||
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
|
||||
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
|
||||
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
||||
|
||||
|
||||
|
||||
72
test/python/cutlass/gemm/gemm_mixed_sm80.py
Normal file
72
test/python/cutlass/gemm/gemm_mixed_sm80.py
Normal file
@ -0,0 +1,72 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Low-level functionality tests for GEMM with mixed operands on SM80
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
|
||||
from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 80
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
class GemmMixedSm80(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1],
|
||||
opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64],
|
||||
warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass.DataType.f32)
|
||||
|
||||
# Test with upcast on A
|
||||
add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT)
|
||||
add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN)
|
||||
|
||||
# Test with upcast on B
|
||||
add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT)
|
||||
add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@ -68,30 +68,30 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.s8, cc=cc
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
|
||||
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
|
||||
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32,
|
||||
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4)
|
||||
|
||||
# Tests using SIMT
|
||||
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
|
||||
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32,
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32,
|
||||
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
|
||||
# Stream K tests
|
||||
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
|
||||
add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
|
||||
add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
|
||||
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
|
||||
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ import subprocess
|
||||
|
||||
import torch
|
||||
|
||||
from cutlass import (
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
DataTypeSize,
|
||||
GemmUniversalMode,
|
||||
@ -49,7 +49,6 @@ from cutlass import (
|
||||
|
||||
from cutlass.backend import compiler
|
||||
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
|
||||
from cutlass.backend.memory_manager import get_allocated_size
|
||||
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
|
||||
from cutlass.shape import GemmCoord, MatrixCoord
|
||||
from cutlass.utils.datatypes import torch_type
|
||||
@ -65,16 +64,6 @@ class GemmUniversalLauncher:
|
||||
compiler_mode= "nvcc",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# Create the reduction kernel, if needed
|
||||
self.reduction_operation: ReductionOperation = ReductionOperation(
|
||||
shape=MatrixCoord(4, 32 * operation.C.alignment),
|
||||
C=operation.C,
|
||||
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
||||
element_compute=operation.epilogue_functor.element_epilogue,
|
||||
epilogue_functor=operation.epilogue_functor,
|
||||
count=operation.C.alignment,
|
||||
)
|
||||
|
||||
self.math_operation = operation.tile_description.math_instruction.math_operation
|
||||
self.verification = verification
|
||||
|
||||
@ -88,19 +77,26 @@ class GemmUniversalLauncher:
|
||||
op_list = [operation]
|
||||
if operation.arch < 90:
|
||||
# Split K via Python is currently only supported for pre-SM90 kernels
|
||||
self.reduction_operation: ReductionOperation = ReductionOperation(
|
||||
shape=MatrixCoord(4, 32 * operation.C.alignment),
|
||||
C=operation.C,
|
||||
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
||||
element_compute=operation.epilogue_functor.element_epilogue,
|
||||
epilogue_functor=operation.epilogue_functor,
|
||||
count=operation.C.alignment,
|
||||
)
|
||||
op_list.append(self.reduction_operation)
|
||||
|
||||
compiler.add_module(op_list, bypass_cache=False)
|
||||
|
||||
self.operation = operation
|
||||
|
||||
self.dtype_A = torch_type(operation.A.element)
|
||||
self.dtype_B = torch_type(operation.B.element)
|
||||
self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element)
|
||||
self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element)
|
||||
self.dtype_C = torch_type(operation.C.element)
|
||||
self.dtype_D = torch_type(operation.C.element)
|
||||
self.dtype_D = torch_type(operation.epilogue_functor.element_output)
|
||||
|
||||
accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator]
|
||||
element_size = DataTypeSize[operation.A.element]
|
||||
element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])
|
||||
|
||||
if element_size == 1:
|
||||
self.rand_max = 1
|
||||
@ -154,7 +150,18 @@ class GemmUniversalLauncher:
|
||||
def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
|
||||
# If any tensor is on CPU, place all tensors on CPU unless only
|
||||
# tensor C is on CPU
|
||||
devices = [x.device.type for x in [tensor_A, tensor_B, tensor_C]]
|
||||
# Handle mixed-input cases by casting to the larger data type and overriding
|
||||
# to whatever the data type of the larger type is
|
||||
if self.dtype_A != self.dtype_B:
|
||||
if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]:
|
||||
tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device)
|
||||
else:
|
||||
tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device)
|
||||
|
||||
devices = [x.device.type for x in [tensor_A, tensor_B]]
|
||||
if tensor_C is not None:
|
||||
devices.append(tensor_C.device.type)
|
||||
|
||||
if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]:
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
@ -162,14 +169,17 @@ class GemmUniversalLauncher:
|
||||
|
||||
tensor_A = tensor_A.to(device)
|
||||
tensor_B = tensor_B.to(device)
|
||||
tensor_C = tensor_C.to(device)
|
||||
if tensor_C is not None:
|
||||
tensor_C = tensor_C.to(device)
|
||||
|
||||
dtype = torch_type(self.compute_type)
|
||||
alpha_torch = torch.tensor([alpha], device=device).to(dtype)
|
||||
beta_torch = torch.tensor([beta], device=device).to(dtype)
|
||||
|
||||
tmp = tensor_A @ tensor_B
|
||||
tensor_D_ref = (alpha_torch * tmp) + (tensor_C * beta_torch)
|
||||
tensor_D_ref = (alpha_torch * tmp)
|
||||
if tensor_C is not None:
|
||||
tensor_D_ref += (tensor_C * beta_torch)
|
||||
return tensor_D_ref.to(self.dtype_D)
|
||||
|
||||
def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
|
||||
@ -199,12 +209,22 @@ class GemmUniversalLauncher:
|
||||
self.dtype_B,
|
||||
self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout),
|
||||
)
|
||||
tensor_C, tensor_C_ref = self.uniform_init(
|
||||
if self.dtype_C is not None:
|
||||
tensor_C, tensor_C_ref = self.uniform_init(
|
||||
(true_batch_count, problem_size.m, problem_size.n),
|
||||
self.dtype_C,
|
||||
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
|
||||
)
|
||||
else:
|
||||
tensor_C = None
|
||||
tensor_C_ref = None
|
||||
|
||||
tensor_D, _ = self.uniform_init(
|
||||
(true_batch_count, problem_size.m, problem_size.n),
|
||||
self.dtype_C,
|
||||
self.dtype_D,
|
||||
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
|
||||
)
|
||||
tensor_D = torch.zeros_like(tensor_C)
|
||||
tensor_D = torch.zeros_like(tensor_D)
|
||||
|
||||
if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
|
||||
alpha = int(alpha)
|
||||
@ -248,6 +268,10 @@ class GemmUniversalLauncher:
|
||||
if self.verification:
|
||||
if mode == GemmUniversalMode.GemmSplitKParallel:
|
||||
reduction_arguments.sync()
|
||||
|
||||
# Free memory allocated by args because we are not
|
||||
# calling `arguments.sync()` in this case (which will free memory)
|
||||
arguments.free()
|
||||
else:
|
||||
arguments.sync()
|
||||
tensor_D_ref = self.reference(
|
||||
@ -274,9 +298,6 @@ class GemmUniversalLauncher:
|
||||
if mode == GemmUniversalMode.GemmSplitKParallel:
|
||||
del reduction_arguments
|
||||
|
||||
cur_size = get_allocated_size()
|
||||
assert cur_size == 0, f"{cur_size} B of memory were not released after this run"
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
|
||||
@ -30,9 +30,10 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import cutlass
|
||||
from cutlass_library import SubstituteTemplate
|
||||
|
||||
from cutlass import (
|
||||
import cutlass
|
||||
from cutlass_library import (
|
||||
DataTypeNames,
|
||||
EpilogueScheduleSuffixes,
|
||||
KernelScheduleSuffixes,
|
||||
@ -42,7 +43,6 @@ from cutlass import (
|
||||
ShortLayoutTypeNames
|
||||
)
|
||||
from cutlass.backend import library
|
||||
from cutlass.backend.utils.software import SubstituteTemplate
|
||||
|
||||
from gemm_testbed import test_all_gemm
|
||||
|
||||
@ -82,6 +82,7 @@ def get_name(
|
||||
stages,
|
||||
element_a,
|
||||
element_b,
|
||||
element_c,
|
||||
arch,
|
||||
opclass,
|
||||
kernel_schedule=None,
|
||||
@ -102,6 +103,7 @@ def get_name(
|
||||
:type stages: int
|
||||
:param element_a: data type of operand A
|
||||
:param element_b: data type of operand B
|
||||
:param element_c: data type of operand C
|
||||
:param arch: compute capability of kernel being generated
|
||||
:type arch: int
|
||||
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
|
||||
@ -122,7 +124,7 @@ def get_name(
|
||||
"arch": str(arch),
|
||||
"eA": DataTypeNames[element_a],
|
||||
"eB": DataTypeNames[element_b],
|
||||
"eC": DataTypeNames[element_output],
|
||||
"eC": DataTypeNames[element_c],
|
||||
"lA": ShortLayoutTypeNames[layouts[0]],
|
||||
"lB": ShortLayoutTypeNames[layouts[1]],
|
||||
"lC": ShortLayoutTypeNames[layouts[2]],
|
||||
@ -161,7 +163,10 @@ def add_test_gemm(
|
||||
swizzle=None,
|
||||
kernel_schedule=None,
|
||||
epilogue_schedule=None,
|
||||
compilation_modes=['nvcc', 'nvrtc']):
|
||||
compilation_modes=['nvcc', 'nvrtc'],
|
||||
element_A=None,
|
||||
element_B=None,
|
||||
element_C=None):
|
||||
"""
|
||||
Create test-running functions with the given specification and set it as a method of ``cls``.
|
||||
|
||||
@ -195,22 +200,38 @@ def add_test_gemm(
|
||||
:param epilogue_schedule: epilogue schedule to use
|
||||
:type epilogue_schedule: cutlass.EpilogueScheduleType
|
||||
:param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc')
|
||||
:type compilation_modes: list
|
||||
:type compilation_modes: list,
|
||||
:param element_A: data type of operand A. If set, overrides ``element``
|
||||
:type element_A: cutlass.DataType
|
||||
:param element_B: data type of operand B. If set, overrides ``element``
|
||||
:type element_B: cutlass.DataType
|
||||
:param element_C: data type of operand C. If set, overrides ``element``
|
||||
:type element_C: cutlass.DataType
|
||||
"""
|
||||
|
||||
if element_A is None:
|
||||
element_A = element
|
||||
if element_B is None:
|
||||
element_B = element
|
||||
if element_C is None:
|
||||
element_C = element
|
||||
if element_output is None:
|
||||
element_output = element
|
||||
if element_accumulator is None:
|
||||
element_accumulator = element
|
||||
|
||||
for compilation_mode in compilation_modes:
|
||||
def run(self):
|
||||
"""
|
||||
Dynamically-generated function that constructs a GEMM operation and verifies it against
|
||||
multiple test cases.
|
||||
"""
|
||||
element_A = element
|
||||
element_B = element
|
||||
|
||||
layout_A, layout_B, layout_C = layouts
|
||||
alignment_A, alignment_B, alignment_C = alignments
|
||||
|
||||
plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B,
|
||||
element_C=element_output, element_D=element_output,
|
||||
element_C=element_C, element_D=element_output,
|
||||
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
|
||||
element_accumulator=element_accumulator,
|
||||
kernel_cc=cc)
|
||||
@ -233,7 +254,7 @@ def add_test_gemm(
|
||||
name = get_name(
|
||||
layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
|
||||
element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
|
||||
stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass,
|
||||
stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass,
|
||||
kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')
|
||||
|
||||
setattr(cls, name, run)
|
||||
|
||||
57
test/python/cutlass/installation.py
Normal file
57
test/python/cutlass/installation.py
Normal file
@ -0,0 +1,57 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Tests for a successful installation of the CUTLASS Python interface
|
||||
"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import cutlass
|
||||
import cutlass_library
|
||||
|
||||
|
||||
class InstallationTest(unittest.TestCase):
|
||||
def test_cutlass_source_paths(self):
|
||||
"""
|
||||
Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages
|
||||
"""
|
||||
src_file = 'include/cutlass/cutlass.h'
|
||||
library_file = os.path.join(cutlass_library.source_path, src_file)
|
||||
cutlass_file = os.path.join(cutlass.CUTLASS_PATH, src_file)
|
||||
assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded."
|
||||
assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -50,7 +50,7 @@ class Conv2dEquivalence:
|
||||
"""
|
||||
def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator,
|
||||
alignment_A, alignment_B, alignment_C):
|
||||
|
||||
|
||||
self.element_A = element_A
|
||||
self.element_B = element_B
|
||||
self.element_C = element_C
|
||||
@ -59,21 +59,21 @@ class Conv2dEquivalence:
|
||||
self.alignment_A = alignment_A
|
||||
self.alignment_B = alignment_B
|
||||
self.alignment_C = alignment_C
|
||||
|
||||
|
||||
self.conv_kind = conv_kind
|
||||
|
||||
|
||||
self.plan = cutlass.op.Conv2d(
|
||||
kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C,
|
||||
element_D=element_D, element_accumulator=element_accumulator)
|
||||
|
||||
|
||||
self.op = self.plan.construct(
|
||||
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
|
||||
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
|
||||
alignment_C=self.alignment_C)
|
||||
|
||||
|
||||
def _plans_equal(self, other_plan) -> bool:
|
||||
"""
|
||||
Compares whether two plans are equal
|
||||
|
||||
|
||||
:param other_plan: plan to compare against the default Conv2d
|
||||
:type other_plan: cutlass.op.Conv2d
|
||||
|
||||
@ -81,9 +81,9 @@ class Conv2dEquivalence:
|
||||
:rtype: bool
|
||||
"""
|
||||
other_op = other_plan.construct(
|
||||
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
|
||||
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
|
||||
alignment_C=self.alignment_C)
|
||||
|
||||
|
||||
return self.op.rt_module.emit() == other_op.rt_module.emit()
|
||||
|
||||
def generic_test(self):
|
||||
@ -91,16 +91,16 @@ class Conv2dEquivalence:
|
||||
Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types
|
||||
and layouts for constructing the Conv2d interface
|
||||
"""
|
||||
if not datatypes.numpy_available:
|
||||
if not datatypes.is_numpy_available():
|
||||
return
|
||||
|
||||
|
||||
# Test when specifying all parameters
|
||||
plan_other = cutlass.op.Conv2d(
|
||||
kind=self.conv_kind,
|
||||
element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
|
||||
element_D=self.element_D, element_accumulator=self.element_accumulator)
|
||||
assert self._plans_equal(plan_other)
|
||||
|
||||
|
||||
# Test when specifying all parameters but A
|
||||
plan_other = cutlass.op.Conv2d(
|
||||
kind=self.conv_kind,
|
||||
@ -108,7 +108,7 @@ class Conv2dEquivalence:
|
||||
element_D=self.element_D, element_accumulator=self.element_accumulator,
|
||||
element=self.element_A)
|
||||
assert self._plans_equal(plan_other)
|
||||
|
||||
|
||||
# Test when specifying all parameters but A and B as tensors using generic element and output
|
||||
plan_other = cutlass.op.Conv2d(
|
||||
kind=self.conv_kind,
|
||||
@ -116,7 +116,7 @@ class Conv2dEquivalence:
|
||||
element_D=self.element_D, element_accumulator=self.element_accumulator,
|
||||
element=self.element_A)
|
||||
assert self._plans_equal(plan_other)
|
||||
|
||||
|
||||
# Test without explicit accumulator. Only run if the type of C and the accumulator are equal
|
||||
if self.element_C == self.element_accumulator:
|
||||
plan_other = cutlass.op.Conv2d(
|
||||
@ -125,18 +125,18 @@ class Conv2dEquivalence:
|
||||
element_D=self.element_D,
|
||||
element=self.element_A)
|
||||
assert self._plans_equal(plan_other)
|
||||
|
||||
|
||||
# Test with only the generic types. Only rune if the types of A, B, C, and D are the same
|
||||
if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
|
||||
and self.element_A == self.element_accumulator):
|
||||
plan_other = cutlass.op.Conv2d(kind=self.conv_kind, element=self.element_A)
|
||||
assert self._plans_equal(plan_other)
|
||||
|
||||
|
||||
def numpy_test(self):
|
||||
"""
|
||||
Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend
|
||||
"""
|
||||
if not datatypes.numpy_available:
|
||||
if not datatypes.is_numpy_available():
|
||||
return
|
||||
|
||||
import numpy as np
|
||||
@ -145,7 +145,7 @@ class Conv2dEquivalence:
|
||||
type_C = datatypes.numpy_type(self.element_C)
|
||||
type_D = datatypes.numpy_type(self.element_D)
|
||||
type_accum = datatypes.numpy_type(self.element_accumulator)
|
||||
|
||||
|
||||
size = (2, 2)
|
||||
A = np.zeros(size, dtype=type_A)
|
||||
B = np.zeros(size, dtype=type_B)
|
||||
@ -153,49 +153,49 @@ class Conv2dEquivalence:
|
||||
D = np.zeros(size, dtype=type_D)
|
||||
|
||||
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
|
||||
|
||||
|
||||
def torch_test(self):
|
||||
"""
|
||||
Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend
|
||||
"""
|
||||
if not datatypes.torch_available:
|
||||
if not datatypes.is_torch_available():
|
||||
return
|
||||
|
||||
|
||||
import torch
|
||||
type_A = datatypes.torch_type(self.element_A)
|
||||
type_B = datatypes.torch_type(self.element_B)
|
||||
type_C = datatypes.torch_type(self.element_C)
|
||||
type_D = datatypes.torch_type(self.element_D)
|
||||
type_accum = datatypes.torch_type(self.element_accumulator)
|
||||
|
||||
|
||||
size = (2, 2)
|
||||
|
||||
|
||||
A = torch.empty(size, dtype=type_A)
|
||||
B = torch.empty(size, dtype=type_B)
|
||||
C = torch.empty(size, dtype=type_C)
|
||||
D = torch.empty(size, dtype=type_D)
|
||||
|
||||
|
||||
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
|
||||
|
||||
|
||||
def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D):
|
||||
# Test when specifying all parameters via tensors
|
||||
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum)
|
||||
assert self._plans_equal(plan_np)
|
||||
|
||||
|
||||
# Test when specifying all parameters but A as tensors
|
||||
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A)
|
||||
assert self._plans_equal(plan_np)
|
||||
|
||||
|
||||
# Test when specifying all parameters but A and B as tensors and using generic element and output
|
||||
if type_A == type_B:
|
||||
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A)
|
||||
assert self._plans_equal(plan_np)
|
||||
|
||||
|
||||
# Test without explicit accumulator. Only run if the type of C and the accumulator.
|
||||
if type_C == type_accum:
|
||||
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D)
|
||||
assert self._plans_equal(plan_np)
|
||||
|
||||
|
||||
# Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
|
||||
if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum):
|
||||
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, element=type_A)
|
||||
@ -223,20 +223,20 @@ type2alignment = {
|
||||
}
|
||||
|
||||
def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator):
|
||||
|
||||
|
||||
test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}"
|
||||
|
||||
|
||||
def run(self):
|
||||
conv2d_eq = Conv2dEquivalence(
|
||||
conv_kind=conv_kind,
|
||||
conv_kind=conv_kind,
|
||||
element_A=element_A, element_B=element_B,
|
||||
element_C=element_C, element_D=element_D,
|
||||
element_accumulator=element_accumulator,
|
||||
element_accumulator=element_accumulator,
|
||||
alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B],
|
||||
alignment_C=type2alignment[element_C]
|
||||
)
|
||||
conv2d_eq.test_all()
|
||||
|
||||
|
||||
setattr(ConvEquivalenceTest, test_name, run)
|
||||
|
||||
for conv_kind in ["fprop", "wgrad", "dgrad"]:
|
||||
@ -255,25 +255,25 @@ class Conv2dErrorTests(unittest.TestCase):
|
||||
"""
|
||||
Tests various error scenarios that arise with the high-level Gemm interface
|
||||
"""
|
||||
|
||||
|
||||
def test_alignment(self):
|
||||
"""
|
||||
Tests case in which the alignment specified is unsupported
|
||||
"""
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16)
|
||||
|
||||
|
||||
with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'):
|
||||
op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3)
|
||||
|
||||
|
||||
def test_invalid_tile_description(self):
|
||||
"""
|
||||
Tests scenarios in which an invalid tile description is provided for a given CC
|
||||
"""
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16)
|
||||
|
||||
|
||||
td = plan.tile_descriptions()[0]
|
||||
td.threadblock_shape=[17, 32, 5]
|
||||
|
||||
|
||||
plan.tile_description = td
|
||||
with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'):
|
||||
plan.compile()
|
||||
|
||||
@ -93,13 +93,16 @@ class EVTErrorTests(unittest.TestCase):
|
||||
"""
|
||||
Test when the epilogue consumes too much shared memory
|
||||
"""
|
||||
def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5):
|
||||
def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8):
|
||||
D1 = accum + C1
|
||||
D2 = D1 + C2
|
||||
D3 = D2 + C3
|
||||
D4 = D3 + C4
|
||||
D = D4 + C5
|
||||
return D, D1, D2, D3, D4
|
||||
D5 = D4 + C5
|
||||
D6 = D5 + C6
|
||||
D7 = D6 + C7
|
||||
D = D7 + C8
|
||||
return D, D1, D2, D3, D4, D5, D6, D7
|
||||
|
||||
example_tensors = {
|
||||
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
@ -108,10 +111,16 @@ class EVTErrorTests(unittest.TestCase):
|
||||
"C3": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"C4": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"C5": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"C6": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"C7": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"C8": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D1": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D2": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D3": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D4": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D5": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D6": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D7": self.fake_tensor(np.float16, (6, 512, 512)),
|
||||
"D": self.fake_tensor(np.float16, (6, 512, 512))
|
||||
}
|
||||
|
||||
|
||||
@ -85,7 +85,7 @@ class GemmEquivalence:
|
||||
Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types
|
||||
and layouts for constructing the Gemm interface
|
||||
"""
|
||||
if not datatypes.numpy_available:
|
||||
if not datatypes.is_numpy_available():
|
||||
return
|
||||
|
||||
# Test when specifying all parameters
|
||||
@ -126,7 +126,7 @@ class GemmEquivalence:
|
||||
"""
|
||||
Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend
|
||||
"""
|
||||
if not datatypes.numpy_available:
|
||||
if not datatypes.is_numpy_available():
|
||||
return
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -114,7 +114,7 @@ void run_test_integer_range_all() {
|
||||
);
|
||||
|
||||
destination.sync_host();
|
||||
|
||||
|
||||
// Verify conversion
|
||||
bool passed = true;
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
@ -124,7 +124,7 @@ void run_test_integer_range_all() {
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(passed) << " FastNumericArrayConverter failed";
|
||||
|
||||
|
||||
// Print out results for the failed conversion.
|
||||
if (!passed) {
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
|
||||
@ -48,11 +48,20 @@ TEST(float_e4m3_t, host_conversion) {
|
||||
for (int i = -8; i < 8; ++i) {
|
||||
float f = static_cast<float>(i);
|
||||
|
||||
cutlass::int4b_t s = static_cast<cutlass::int4b_t>(i);
|
||||
FP8 w = static_cast<FP8>(s);
|
||||
FP8 x = static_cast<FP8>(i);
|
||||
FP8 y = static_cast<FP8>(f);
|
||||
|
||||
EXPECT_TRUE(static_cast<cutlass::int4b_t>(w) == s);
|
||||
EXPECT_TRUE(static_cast<int>(x) == i);
|
||||
EXPECT_TRUE(static_cast<float>(y) == f);
|
||||
|
||||
if (i >= 0) {
|
||||
cutlass::uint4b_t u = static_cast<cutlass::uint4b_t>(i);
|
||||
FP8 z = static_cast<FP8>(u);
|
||||
EXPECT_TRUE(static_cast<unsigned>(z) == u);
|
||||
}
|
||||
}
|
||||
|
||||
// Try out default-ctor (zero initialization of primitive proxy type)
|
||||
@ -72,11 +81,20 @@ TEST(float_e5m2_t, host_conversion) {
|
||||
for (int i = -8; i < 8; ++i) {
|
||||
float f = static_cast<float>(i);
|
||||
|
||||
cutlass::int4b_t s = static_cast<cutlass::int4b_t>(i);
|
||||
FP8 w = static_cast<FP8>(s);
|
||||
FP8 x = static_cast<FP8>(i);
|
||||
FP8 y = static_cast<FP8>(f);
|
||||
|
||||
EXPECT_TRUE(static_cast<cutlass::int4b_t>(w) == s);
|
||||
EXPECT_TRUE(static_cast<int>(x) == i);
|
||||
EXPECT_TRUE(static_cast<float>(y) == f);
|
||||
|
||||
if (i >= 0) {
|
||||
cutlass::uint4b_t u = static_cast<cutlass::uint4b_t>(i);
|
||||
FP8 z = static_cast<FP8>(u);
|
||||
EXPECT_TRUE(static_cast<cutlass::uint4b_t>(z) == u);
|
||||
}
|
||||
}
|
||||
|
||||
// Try out default-ctor (zero initialization of primitive proxy type)
|
||||
|
||||
@ -60,7 +60,7 @@ __global__ void convert(
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Destination, typename Source, int Count>
|
||||
template <typename Destination, typename Source, int Count, int Range = 4>
|
||||
void run_test(const char dest_name[], const char source_name[]) {
|
||||
const int kN = Count;
|
||||
|
||||
@ -69,9 +69,11 @@ void run_test(const char dest_name[], const char source_name[]) {
|
||||
|
||||
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
|
||||
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
|
||||
auto source_ref = source.host_ref();
|
||||
auto destination_ref = destination.host_ref();
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
source.host_data()[i] = Source(i % 4);
|
||||
source_ref.at({0, i}) = Source(i % Range);
|
||||
}
|
||||
|
||||
source.sync_device();
|
||||
@ -84,9 +86,67 @@ void run_test(const char dest_name[], const char source_name[]) {
|
||||
destination.sync_host();
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]))
|
||||
<< "Destination type: " << dest_name
|
||||
<< ", Source type: " << source_name
|
||||
EXPECT_TRUE(float(destination_ref.at({0, i})) == float(source_ref.at({0, i})))
|
||||
<< "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i}))
|
||||
<< ", Source type: " << source_name << " " << float(source_ref.at({0, i}))
|
||||
<< ", Count: " << Count;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Destination, typename Source, typename ScaleFactor, int Count>
|
||||
__global__ void convert_with_scale_factor(
|
||||
cutlass::Array<Destination, Count> *destination,
|
||||
cutlass::Array<Source, Count> const *source,
|
||||
cutlass::Array<ScaleFactor, Count> const *scale_factor) {
|
||||
|
||||
cutlass::NumericArrayConverter<Destination, Source, Count> convert;
|
||||
|
||||
*destination = convert(*source, *scale_factor);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Destination, typename Source, typename ScaleFactor, int Count, int Range = 4>
|
||||
void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[]) {
|
||||
const int kN = Count;
|
||||
|
||||
dim3 grid(1, 1);
|
||||
dim3 block(1, 1);
|
||||
|
||||
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
|
||||
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
|
||||
cutlass::HostTensor<ScaleFactor, cutlass::layout::RowMajor> scale_factor({1, kN});
|
||||
auto source_ref = source.host_ref();
|
||||
auto destination_ref = destination.host_ref();
|
||||
auto scale_factor_ref = scale_factor.host_ref();
|
||||
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
source_ref.at({0, i}) = Source(i % Range);
|
||||
}
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
scale_factor_ref.at({0, i}) = ScaleFactor(1 + i % 8);
|
||||
}
|
||||
|
||||
source.sync_device();
|
||||
scale_factor.sync_device();
|
||||
|
||||
convert_with_scale_factor<Destination, Source, ScaleFactor, kN><<< grid, block >>>(
|
||||
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
|
||||
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data()),
|
||||
reinterpret_cast<cutlass::Array<ScaleFactor, kN> const *>(scale_factor.device_data())
|
||||
);
|
||||
|
||||
destination.sync_host();
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
float ref = float(source_ref.at({0, i})) / float(scale_factor_ref.at({0, i}));
|
||||
EXPECT_TRUE(float(destination_ref.at({0, i})) == ref)
|
||||
<< "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i}))
|
||||
<< ", Source type: " << source_name << " " << float(source_ref.at({0, i}))
|
||||
<< ", Count: " << Count;
|
||||
}
|
||||
}
|
||||
@ -98,7 +158,16 @@ void run_test(const char dest_name[], const char source_name[]) {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(NumericConversion, f32_to_f16_rn) {
|
||||
int const kN = 1;
|
||||
constexpr int kN = 1;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::half_t;
|
||||
const char dest_name[] = "half_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f32x2_to_f16x2_rn) {
|
||||
constexpr int kN = 2;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::half_t;
|
||||
@ -107,7 +176,7 @@ TEST(NumericConversion, f32_to_f16_rn) {
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f32x8_to_f16x8_rn) {
|
||||
int const kN = 8;
|
||||
constexpr int kN = 8;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::half_t;
|
||||
@ -394,4 +463,50 @@ TEST(NumericConversion, fe5m2_to_bf16_array) {
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// These are included as regression tests for a special case when N = 4.
|
||||
TEST(NumericConversion, int4b_t_to_fe5m2_t_array_4) {
|
||||
int const kN = 4;
|
||||
using Source = cutlass::int4b_t;
|
||||
const char source_name[] = "int4b_t";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, int_to_fe4m3_t_array_4) {
|
||||
int const kN = 4;
|
||||
using Source = int;
|
||||
const char source_name[] = "int";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, int2b_t_to_fe4m3_t_array_4) {
|
||||
int const kN = 4;
|
||||
using Source = cutlass::int2b_t;
|
||||
const char source_name[] = "int2b_t";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_t_to_double_array_4) {
|
||||
int const kN = 4;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = double;
|
||||
const char dest_name[] = "double";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, int_to_fe4m3_t_array_32) {
|
||||
int const kN = 32;
|
||||
using Source = int;
|
||||
const char source_name[] = "int";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@
|
||||
#include "cutlass_unit_test.h"
|
||||
|
||||
#include <cutlass/trace.h>
|
||||
|
||||
#include <cute/pointer.hpp>
|
||||
|
||||
TEST(CuTe_core, Pointer)
|
||||
@ -45,7 +46,7 @@ TEST(CuTe_core, Pointer)
|
||||
// Test T* overloads (T can be nonconst or const)
|
||||
{
|
||||
using T = float;
|
||||
using expected_type = cute::gmem_ptr<T>;
|
||||
using expected_type = cute::gmem_ptr<T*>;
|
||||
T* p = nullptr;
|
||||
|
||||
// explicit template argument
|
||||
@ -58,7 +59,7 @@ TEST(CuTe_core, Pointer)
|
||||
}
|
||||
{
|
||||
using T = float const;
|
||||
using expected_type = cute::gmem_ptr<T>;
|
||||
using expected_type = cute::gmem_ptr<T*>;
|
||||
T* p = nullptr;
|
||||
|
||||
// explicit template argument
|
||||
@ -74,7 +75,7 @@ TEST(CuTe_core, Pointer)
|
||||
// (these require an explicit template argument)
|
||||
{
|
||||
using T = float;
|
||||
using expected_type = cute::gmem_ptr<T>;
|
||||
using expected_type = cute::gmem_ptr<T*>;
|
||||
void* p = nullptr;
|
||||
|
||||
auto gmem_p0 = cute::make_gmem_ptr<T>(p);
|
||||
@ -82,7 +83,7 @@ TEST(CuTe_core, Pointer)
|
||||
}
|
||||
{
|
||||
using T = float const;
|
||||
using expected_type = cute::gmem_ptr<T>;
|
||||
using expected_type = cute::gmem_ptr<T*>;
|
||||
void const* p = nullptr;
|
||||
|
||||
auto gmem_p0 = cute::make_gmem_ptr<T>(p);
|
||||
@ -92,14 +93,14 @@ TEST(CuTe_core, Pointer)
|
||||
// Test nullptr_t overload.
|
||||
{
|
||||
using T = float;
|
||||
using expected_type = cute::gmem_ptr<T>;
|
||||
using expected_type = cute::gmem_ptr<T*>;
|
||||
|
||||
auto gmem_p0 = cute::make_gmem_ptr<T>(nullptr);
|
||||
static_assert(cute::is_same_v<decltype(gmem_p0), expected_type>);
|
||||
}
|
||||
{
|
||||
using T = float const;
|
||||
using expected_type = cute::gmem_ptr<T>;
|
||||
using expected_type = cute::gmem_ptr<T*>;
|
||||
|
||||
auto gmem_p0 = cute::make_gmem_ptr<T>(nullptr);
|
||||
static_assert(cute::is_same_v<decltype(gmem_p0), expected_type>);
|
||||
|
||||
@ -416,7 +416,6 @@ TEST(SM90_CuTe_Hopper, Tma_Load_InternalType)
|
||||
test_tma_load<half_t, uint64_t>(gmem_layout, smem_layout);
|
||||
test_tma_load< float, uint64_t>(gmem_layout, smem_layout);
|
||||
test_tma_load<double, uint64_t>(gmem_layout, smem_layout);
|
||||
|
||||
}
|
||||
|
||||
// Complex<double> is 128bit, which the TMA has no concept of
|
||||
|
||||
@ -43,7 +43,7 @@ namespace cutlass::test {
|
||||
template <class ElementType, class SmemLayout>
|
||||
struct SharedStorage
|
||||
{
|
||||
cute::array_aligned<ElementType, cute::cosize_v<SmemLayout>> smem;
|
||||
cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
|
||||
cute::uint64_t tma_load_mbar[1];
|
||||
};
|
||||
|
||||
@ -62,26 +62,26 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
extern __shared__ char shared_memory[];
|
||||
using SharedStorage = SharedStorage<T, SmemLayout>;
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// Construct SMEM tensor
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
|
||||
// Shared memory barriers use 64bits in SMEM for synchronization
|
||||
uint64_t* tma_load_mbar = shared_storage.tma_load_mbar;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA = tma.get_tma_tensor(shape(gmem_layout));
|
||||
Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout);
|
||||
Tensor mB = make_tensor(make_gmem_ptr<T>(g_out), gmem_layout);
|
||||
|
||||
constexpr int R = rank_v<CTA_Tiler>;
|
||||
Tensor gA = local_tile(mA, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
Tensor gB = local_tile(mB, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
|
||||
//
|
||||
// Prepare the TMA_LOAD
|
||||
//
|
||||
|
||||
auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice
|
||||
|
||||
Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
|
||||
Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N)
|
||||
|
||||
@ -89,11 +89,13 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
if (thread0()) {
|
||||
print(tma);
|
||||
print("TILE : "); print(cta_tiler); print("\n");
|
||||
print(" mA : "); print( mA.data()); print(" o "); print( mA.layout()); print("\n");
|
||||
print(" gA : "); print( gA.data()); print(" o "); print( gA.layout()); print("\n");
|
||||
print("tAgA_x: "); print(tAgA_x.data()); print(" o "); print(tAgA_x.layout()); print("\n");
|
||||
print(" sA : "); print( sA.data()); print(" o "); print( sA.layout()); print("\n");
|
||||
print("tAsA_x: "); print(tAsA_x.data()); print(" o "); print(tAsA_x.layout()); print("\n");
|
||||
print(" mA : "); print( mA); print("\n");
|
||||
print(" mB : "); print( mB); print("\n");
|
||||
print(" gA : "); print( gA); print("\n");
|
||||
print(" gB : "); print( gB); print("\n");
|
||||
print(" sA : "); print( sA); print("\n");
|
||||
print("tAgA_x: "); print(tAgA_x); print("\n");
|
||||
print("tAsA_x: "); print(tAsA_x); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -111,9 +113,9 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n");
|
||||
print("tAsA : "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n");
|
||||
print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n");
|
||||
print("tAgA : "); print(tAgA); print("\n");
|
||||
print("tAsA : "); print(tAsA); print("\n");
|
||||
print("tBgB : "); print(tBgB); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -121,7 +123,7 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
for (int stage = 0; stage < size<1>(tAgA); ++stage)
|
||||
{
|
||||
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
|
||||
constexpr int kTmaTransactionBytes = size(sA) * sizeof_bits_v<T> / 8;
|
||||
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<T, size(sA)>);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
@ -146,9 +148,15 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
// print_tensor(sA);
|
||||
//}
|
||||
|
||||
for (int i = threadIdx.x; i < size(sA); i += blockDim.x) {
|
||||
tBgB(i,stage) = sA(i);
|
||||
// for (int i = threadIdx.x; i < size(sA); i += blockDim.x) {
|
||||
// tBgB(i,stage) = sA(i);
|
||||
// }
|
||||
|
||||
// Subbyte elements could cause race conditions, so be even more conservative
|
||||
if (thread0()) {
|
||||
copy(sA, tBgB(_,stage));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
@ -161,30 +169,38 @@ test_tma_load(CopyOp const& copy_op,
|
||||
CTA_Tile const& cta_tile)
|
||||
{
|
||||
using namespace cute;
|
||||
thrust::host_vector<T> h_in(cosize(gmem_layout));
|
||||
for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); }
|
||||
thrust::device_vector<T> d_in = h_in;
|
||||
thrust::device_vector<T> d_out(h_in.size(), T(-1));
|
||||
|
||||
Tensor gA = make_tensor(d_in.data().get(), gmem_layout);
|
||||
// Allocate and initialize host test data
|
||||
size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
|
||||
thrust::host_vector<char> h_in(N);
|
||||
Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
|
||||
for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); }
|
||||
|
||||
// Allocate and initialize device test data
|
||||
thrust::device_vector<char> d_in = h_in;
|
||||
thrust::device_vector<char> d_out(h_in.size(), char(-1));
|
||||
|
||||
// Create TMA for this device Tensor
|
||||
Tensor gA = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_in.data())), gmem_layout);
|
||||
auto tma = make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
|
||||
//print(tma);
|
||||
|
||||
// Launch
|
||||
int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
|
||||
tma_test_device_cute<<<1, 128, smem_size>>>(
|
||||
thrust::raw_pointer_cast(d_in.data()),
|
||||
thrust::raw_pointer_cast(d_out.data()),
|
||||
reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
|
||||
reinterpret_cast<T*> (raw_pointer_cast(d_out.data())),
|
||||
tma, cta_tile,
|
||||
gmem_layout,
|
||||
smem_layout);
|
||||
|
||||
thrust::host_vector<T> h_out = d_out;
|
||||
// Copy results back to host
|
||||
thrust::host_vector<char> h_out = d_out;
|
||||
Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);
|
||||
|
||||
// Validate the results, and tolerate the first 3 errors:
|
||||
Tensor hA_in = make_tensor(h_in.data(), gmem_layout);
|
||||
Tensor hA_out = make_tensor(h_out.data(), gmem_layout);
|
||||
// Validate the results. Print only the first 3 errors.
|
||||
int count = 3;
|
||||
for (int i = 0; i < cute::size(gmem_layout) && count > 0; ++i) {
|
||||
for (int i = 0; i < size(hA_out) && count > 0; ++i) {
|
||||
EXPECT_EQ(hA_in(i), hA_out(i));
|
||||
if (hA_in(i) != hA_out(i)) {
|
||||
--count;
|
||||
|
||||
@ -43,7 +43,7 @@ namespace cutlass::test {
|
||||
template <class ElementType, class SmemLayout>
|
||||
struct SharedStorage
|
||||
{
|
||||
cute::array_aligned<ElementType, cute::cosize_v<SmemLayout>> smem;
|
||||
cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
|
||||
};
|
||||
|
||||
#if CUDA_12_0_SM90_FEATURES_SUPPORTED
|
||||
@ -61,24 +61,24 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
extern __shared__ char shared_memory[];
|
||||
using SharedStorage = SharedStorage<T, SmemLayout>;
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// Construct SMEM tensor
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout);
|
||||
Tensor mA = make_tensor(make_gmem_ptr<T>(g_in), gmem_layout);
|
||||
Tensor mB = tma.get_tma_tensor(shape(gmem_layout));
|
||||
|
||||
constexpr int R = rank_v<CTA_Tiler>;
|
||||
Tensor gA = local_tile(mA, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
Tensor gB = local_tile(mB, cta_tiler, repeat<R>(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
||||
|
||||
//
|
||||
// Prepare the TMA_STORE
|
||||
//
|
||||
|
||||
auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice
|
||||
|
||||
Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N)
|
||||
Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
|
||||
|
||||
@ -121,11 +121,17 @@ tma_test_device_cute(T const* g_in, T* g_out,
|
||||
// Read in trivially gmem -> smem
|
||||
//
|
||||
|
||||
for (int i = threadIdx.x; i < size(sB); i += blockDim.x) {
|
||||
sB(i) = tAgA(i,stage);
|
||||
// for (int i = threadIdx.x; i < size(sB); i += blockDim.x) {
|
||||
// sB(i) = tAgA(i,stage);
|
||||
// }
|
||||
|
||||
// Subbyte elements could cause race conditions, so be even more conservative
|
||||
if (thread0()) {
|
||||
copy(tAgA(_,stage), sB);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
cute::cp_async_wait<0>();
|
||||
|
||||
//
|
||||
// Perform the TMA_STORE
|
||||
@ -148,30 +154,38 @@ test_tma_store(CopyOp const& copy_op,
|
||||
CTA_Tile const& cta_tile)
|
||||
{
|
||||
using namespace cute;
|
||||
thrust::host_vector<T> h_in(cosize(gmem_layout));
|
||||
for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); }
|
||||
thrust::device_vector<T> d_in = h_in;
|
||||
thrust::device_vector<T> d_out(h_in.size(), T(-1));
|
||||
|
||||
Tensor gA = make_tensor(d_out.data().get(), gmem_layout);
|
||||
// Allocate and initialize host test data
|
||||
size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
|
||||
thrust::host_vector<char> h_in(N);
|
||||
Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
|
||||
for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); }
|
||||
|
||||
// Allocate and initialize device test data
|
||||
thrust::device_vector<char> d_in = h_in;
|
||||
thrust::device_vector<char> d_out(h_in.size(), char(-1));
|
||||
|
||||
// Create TMA for this device Tensor
|
||||
Tensor gA = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_out.data())), gmem_layout);
|
||||
auto tma = make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
|
||||
//print(tma);
|
||||
|
||||
// Launch
|
||||
int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
|
||||
tma_test_device_cute<<<1, 128, smem_size>>>(
|
||||
thrust::raw_pointer_cast(d_in.data()),
|
||||
thrust::raw_pointer_cast(d_out.data()),
|
||||
reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
|
||||
reinterpret_cast<T*> (raw_pointer_cast(d_out.data())),
|
||||
tma, cta_tile,
|
||||
gmem_layout,
|
||||
smem_layout);
|
||||
|
||||
thrust::host_vector<T> h_out = d_out;
|
||||
// Copy results back to host
|
||||
thrust::host_vector<char> h_out = d_out;
|
||||
Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);
|
||||
|
||||
// Validate the results, and tolerate the first 3 errors:
|
||||
Tensor hA_in = make_tensor(h_in.data(), gmem_layout);
|
||||
Tensor hA_out = make_tensor(h_out.data(), gmem_layout);
|
||||
// Validate the results. Print only the first 3 errors.
|
||||
int count = 3;
|
||||
for (int i = 0; i < cute::size(gmem_layout) && count > 0; ++i) {
|
||||
for (int i = 0; i < size(hA_out) && count > 0; ++i) {
|
||||
EXPECT_EQ(hA_in(i), hA_out(i));
|
||||
if (hA_in(i) != hA_out(i)) {
|
||||
--count;
|
||||
|
||||
@ -242,6 +242,24 @@ cutlass_test_unit_add_executable(
|
||||
sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu
|
||||
)
|
||||
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
# Upcast on Operand A
|
||||
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu
|
||||
|
||||
# Upcast on Operand B
|
||||
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90
|
||||
|
||||
@ -272,10 +290,22 @@ cutlass_test_unit_add_executable(
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
sm90_gemm_f16_f16_f16_alignx_tensor_op.cu
|
||||
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32.cu
|
||||
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized.cu
|
||||
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_cooperative.cu
|
||||
sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_pingpong.cu
|
||||
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu
|
||||
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized.cu
|
||||
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu
|
||||
sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_pingpong.cu
|
||||
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu
|
||||
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized.cu
|
||||
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_cooperative.cu
|
||||
sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_pingpong.cu
|
||||
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu
|
||||
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized.cu
|
||||
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_cooperative.cu
|
||||
sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_pingpong.cu
|
||||
)
|
||||
|
||||
# Fused epilogue tests
|
||||
@ -298,7 +328,6 @@ cutlass_test_unit_add_executable(
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
|
||||
|
||||
@ -311,7 +340,6 @@ cutlass_test_unit_add_executable(
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu
|
||||
sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90
|
||||
|
||||
@ -319,6 +347,8 @@ cutlass_test_unit_add_executable(
|
||||
BATCH_SIZE 4
|
||||
|
||||
sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu
|
||||
sm90_gemm_f8_f8_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu
|
||||
sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
@ -341,23 +371,6 @@ cutlass_test_unit_add_executable(
|
||||
sm80_gemm_f16_f16_f32_tensor_op_f32.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
# Upcast on Operand A
|
||||
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu
|
||||
|
||||
# Upcast on Operand B
|
||||
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_f64
|
||||
|
||||
|
||||
@ -49,16 +49,18 @@
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
#include "testbed_utils.h"
|
||||
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
#include "cutlass/epilogue/fusion/operations.hpp"
|
||||
|
||||
#include "cute/int_tuple.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace test {
|
||||
namespace gemm {
|
||||
@ -68,9 +70,9 @@ namespace device {
|
||||
|
||||
namespace detail{
|
||||
|
||||
// Helper classes that take default data type when
|
||||
// Helper classes that take default data type when
|
||||
// the Gemm::EpilogueOutputOp does not have ElementCompute
|
||||
// and ElementScalar.
|
||||
// and ElementScalar.
|
||||
// (e.g. when Sm90TreeVisitor is used as FusionCallbacks)
|
||||
template <typename Gemm, typename Default, typename = void>
|
||||
struct ElementComputeType {
|
||||
@ -138,6 +140,34 @@ private:
|
||||
int iterations_ = 20;
|
||||
};
|
||||
|
||||
// The maxium swizzle size to use
|
||||
//
|
||||
// This class, like Splits above makes it harder to confuse
|
||||
// the order of arguments of the various run(...) functions in this file.
|
||||
class MaxSwizzleSize {
|
||||
public:
|
||||
MaxSwizzleSize() = default;
|
||||
|
||||
template<class IntegralNotBool,
|
||||
__CUTE_REQUIRES((std::is_integral_v<IntegralNotBool> &&
|
||||
!std::is_same_v<IntegralNotBool, bool>)) >
|
||||
explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {}
|
||||
explicit operator int() const { return max_swizzle_size_; }
|
||||
private:
|
||||
int max_swizzle_size_ = 1;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename Gemm,
|
||||
template <class T> class ActivationFunctor_ = cutlass::epilogue::thread::Identity
|
||||
@ -161,6 +191,8 @@ struct TestbedImpl {
|
||||
using ElementScalar = typename ElementScalarType<Gemm, ElementCompute>::Type;
|
||||
using ActivationFunctor = ActivationFunctor_<ElementCompute>;
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
|
||||
|
||||
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
|
||||
@ -190,6 +222,7 @@ struct TestbedImpl {
|
||||
using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t<StrideB>;
|
||||
using LayoutTagC = cutlass::detail::StrideToLayoutTagA_t<StrideC>;
|
||||
using LayoutTagD = cutlass::detail::StrideToLayoutTagA_t<StrideD>;
|
||||
using LayoutTagVector = cutlass::layout::PackedVectorLayout;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_a;
|
||||
@ -323,10 +356,10 @@ struct TestbedImpl {
|
||||
// 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode
|
||||
auto a_coord = cutlass::make_Coord(M * L, K);
|
||||
auto c_coord = cutlass::make_Coord(M * L, N);
|
||||
// Cutlass has Row/Col major refers to MxK times KxN matrix product,
|
||||
// Cutlass has Row/Col major refers to MxK times KxN matrix product,
|
||||
// so the HostTensorB should be treated as KxN in "coord"'s view
|
||||
auto b_coord = cutlass::make_Coord(K, N * L);
|
||||
|
||||
|
||||
|
||||
tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
|
||||
tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
|
||||
@ -387,7 +420,7 @@ struct TestbedImpl {
|
||||
std::ofstream file(fname.str());
|
||||
file
|
||||
<< "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
|
||||
<< ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n";
|
||||
<< ", alpha: " << alpha << ", beta: " << beta << "\n\n";
|
||||
|
||||
file
|
||||
<< "A =\n" << tensor_A.host_view()
|
||||
@ -404,7 +437,7 @@ struct TestbedImpl {
|
||||
bool verify(
|
||||
ProblemShapeType problem_size,
|
||||
ElementScalar alpha,
|
||||
ElementScalar beta)
|
||||
ElementScalar beta)
|
||||
{
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto M = cute::size<0>(problem_shape_MNKL);
|
||||
@ -412,13 +445,13 @@ struct TestbedImpl {
|
||||
auto K = cute::size<2>(problem_shape_MNKL);
|
||||
auto L = cute::size<3>(problem_shape_MNKL);
|
||||
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
auto A = cute::make_tensor(detail::make_iterator(tensor_A.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, K, L), stride_a));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
auto B = cute::make_tensor(detail::make_iterator(tensor_B.host_data()),
|
||||
cute::make_layout(cute::make_shape(N, K, L), stride_b));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_c));
|
||||
auto D = cute::make_tensor(reference_D.host_data(),
|
||||
auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_d));
|
||||
auto Bias = cute::make_tensor(static_cast<ElementCompute*>(nullptr),
|
||||
cute::make_layout(cute::make_shape(M, cute::_1{})));
|
||||
@ -451,7 +484,6 @@ struct TestbedImpl {
|
||||
};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
return compare_reference(problem_shape_MNKL, alpha, beta);
|
||||
}
|
||||
|
||||
@ -529,8 +561,10 @@ struct TestbedImpl {
|
||||
ElementScalar alpha = ElementScalar(1),
|
||||
ElementScalar beta = ElementScalar(0),
|
||||
bool profiling = false,
|
||||
detail::Iterations iterations = Iterations{},
|
||||
detail::Splits splits = Splits{})
|
||||
detail::Iterations iterations = detail::Iterations{},
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
|
||||
detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{},
|
||||
detail::Splits splits = detail::Splits{})
|
||||
{
|
||||
// Fail test if insufficient CUDA device
|
||||
if (!sufficient()) {
|
||||
@ -557,7 +591,10 @@ struct TestbedImpl {
|
||||
|
||||
typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args;
|
||||
if constexpr (std::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
|
||||
scheduler_args = { static_cast<int>(splits) };
|
||||
scheduler_args = { static_cast<int>(splits), static_cast<int>(max_swizzle), raster_order };
|
||||
}
|
||||
else {
|
||||
scheduler_args = { static_cast<int>(max_swizzle), raster_order };
|
||||
}
|
||||
|
||||
// DefaultEpilogue
|
||||
@ -613,7 +650,7 @@ struct TestbedImpl {
|
||||
//
|
||||
bool passed = this->verify(problem_size, alpha, beta);
|
||||
if (!passed) {
|
||||
std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta)
|
||||
std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
@ -648,6 +685,8 @@ struct Testbed3x {
|
||||
using LayoutTagC = typename TestBedImpl::LayoutTagC;
|
||||
using LayoutTagD = typename TestBedImpl::LayoutTagD;
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
|
||||
|
||||
// Detail Implementation
|
||||
TestBedImpl impl_;
|
||||
|
||||
@ -661,7 +700,7 @@ struct Testbed3x {
|
||||
uint64_t seed_ = TestBedImpl::kDefaultSeed)
|
||||
: impl_(init_A_, init_B_, init_C_, seed_) {}
|
||||
|
||||
Testbed3x(
|
||||
Testbed3x(
|
||||
typename LayoutTagA::Stride stride_factor_A_,
|
||||
typename LayoutTagB::Stride stride_factor_B_,
|
||||
typename LayoutTagC::Stride stride_factor_C_,
|
||||
@ -684,12 +723,14 @@ struct Testbed3x {
|
||||
typename TestBedImpl::ProblemShapeType problem_size,
|
||||
ElementScalar alpha = ElementScalar(1),
|
||||
ElementScalar beta = ElementScalar(0),
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
|
||||
detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{},
|
||||
detail::Splits splits = detail::Splits{},
|
||||
bool profiling = false,
|
||||
detail::Iterations iterations = detail::Iterations{})
|
||||
{
|
||||
return impl_.run(
|
||||
problem_size, alpha, beta, profiling, iterations, splits
|
||||
problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits
|
||||
);
|
||||
}
|
||||
};
|
||||
@ -722,13 +763,15 @@ struct Testbed3xFusionOperation {
|
||||
using StrideD = typename Kernel::StrideD;
|
||||
using ProblemShapeType = typename Kernel::ProblemShape;
|
||||
using ElementAccumulator = typename Kernel::ElementAccumulator;
|
||||
|
||||
|
||||
//
|
||||
// FusionOperation derived types/queries
|
||||
//
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
static_assert(cute::is_base_of_v<cutlass::epilogue::fusion::FusionOperation, FusionOp>);
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
|
||||
|
||||
// fusion types are potentially void if the fusion is not supported
|
||||
// helper so we don't try to construct HostTensor with void type
|
||||
template <typename T, typename U = uint8_t>
|
||||
@ -744,11 +787,17 @@ struct Testbed3xFusionOperation {
|
||||
cutlass::epilogue::thread::Identity<ElementCompute>>;
|
||||
|
||||
static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported;
|
||||
static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported;
|
||||
static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported;
|
||||
static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported;
|
||||
static constexpr bool IsAuxEnabled = FusionOp::IsAuxOutSupported;
|
||||
static constexpr bool IsAbsMaxEnabled = FusionOp::IsAbsMaxSupported;
|
||||
|
||||
static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported;
|
||||
static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported;
|
||||
static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported &&
|
||||
(cute::is_same_v<ElementD, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>);
|
||||
static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported &&
|
||||
(cute::is_same_v<ElementAux, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>);
|
||||
// Legacy support for deprecated bias-elementwise collective, will be removed next release
|
||||
using EpiloguePolicy = typename Epilogue::DispatchPolicy;
|
||||
static constexpr bool IsLegacy =
|
||||
@ -773,6 +822,7 @@ struct Testbed3xFusionOperation {
|
||||
cutlass::HostTensor<ElementAux , LayoutTagAux > tensor_Aux;
|
||||
cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux;
|
||||
// References
|
||||
cutlass::HostTensor<ElementBias, LayoutTagVector> reference_dbias;
|
||||
cutlass::HostTensor<ElementAux , LayoutTagAux > reference_Aux;
|
||||
cutlass::HostTensor<ElementAmax, LayoutTagScalar> reference_abs_max_Aux;
|
||||
cutlass::HostTensor<ElementAmax, LayoutTagScalar> reference_abs_max_D;
|
||||
@ -791,12 +841,6 @@ struct Testbed3xFusionOperation {
|
||||
// Random distribution with which to initialize the bias vector
|
||||
cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform;
|
||||
|
||||
// Factors used for calculating relative equality. These default
|
||||
// values are borrowed from those used by default in the CUTLASS
|
||||
// profiler for performing relative equality checks.
|
||||
float epsilon = 0.05f;
|
||||
float nonzero_floor = 1.0f / 256.0f;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -853,7 +897,7 @@ struct Testbed3xFusionOperation {
|
||||
else {
|
||||
beta.resize(col_vector_coord);
|
||||
EXPECT_TRUE(impl_.initialize_tensor(beta.host_view(), init_scale, impl_.seed + 2024));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
alpha.resize(scalar_coord, use_device_scalars);
|
||||
@ -885,13 +929,34 @@ struct Testbed3xFusionOperation {
|
||||
bias.sync_device();
|
||||
}
|
||||
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
abs_max_D.resize(scalar_coord);
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(scalar_coord);
|
||||
if constexpr (IsDeBiasEnabled) {
|
||||
bias.resize(col_vector_coord);
|
||||
reference_dbias.resize(col_vector_coord);
|
||||
cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0));
|
||||
cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0));
|
||||
bias.sync_device();
|
||||
}
|
||||
|
||||
if constexpr (IsAuxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledD) {
|
||||
abs_max_D.resize(scalar_coord);
|
||||
// ensure in-place device reductions perform their own initialization
|
||||
cutlass::reference::host::TensorFill(abs_max_D.host_view(),
|
||||
CUTLASS_STL_NAMESPACE::numeric_limits<ElementAmax>::max());
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(scalar_coord);
|
||||
cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0));
|
||||
}
|
||||
|
||||
if constexpr (IsAuxInEnabled) {
|
||||
auto aux_coord = cutlass::make_Coord(M * L, N);
|
||||
auto aux_layout = cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(aux_coord, typename LayoutTagAux::Stride{});
|
||||
tensor_Aux.resize(aux_coord, aux_layout);
|
||||
EXPECT_TRUE(impl_.initialize_tensor(tensor_Aux.host_view(), impl_.init_C, impl_.seed + 2023));
|
||||
tensor_Aux.sync_device();
|
||||
stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t<LayoutTagAux>{}, cute::make_shape(M, N, L));
|
||||
}
|
||||
|
||||
if constexpr (IsAuxOutEnabled) {
|
||||
auto aux_coord = cutlass::make_Coord(M * L, N);
|
||||
auto aux_layout = cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(aux_coord, typename LayoutTagAux::Stride{});
|
||||
tensor_Aux.resize(aux_coord, aux_layout);
|
||||
@ -905,10 +970,14 @@ struct Testbed3xFusionOperation {
|
||||
scale_Aux.sync_device();
|
||||
}
|
||||
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledAux) {
|
||||
abs_max_Aux.resize(scalar_coord);
|
||||
// ensure in-place device reductions perform their own initialization
|
||||
cutlass::reference::host::TensorFill(abs_max_Aux.host_view(),
|
||||
CUTLASS_STL_NAMESPACE::numeric_limits<ElementAmax>::max());
|
||||
abs_max_Aux.sync_device();
|
||||
reference_abs_max_Aux.resize(scalar_coord);
|
||||
cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0));
|
||||
}
|
||||
}
|
||||
|
||||
@ -922,9 +991,16 @@ struct Testbed3xFusionOperation {
|
||||
cutlass::TensorView<Element, Layout> const& lhs,
|
||||
cutlass::TensorView<Element, Layout> const& rhs) const {
|
||||
|
||||
// Factors used for calculating relative equality. CUTLASS's relative-equality
|
||||
// checks in include/cutlass/relatively_equal.h are inspired by
|
||||
// https://floating-point-gui.de/errors/comparison/. This reference suggests using
|
||||
// the minimum normal value of a given type as the nonzero_floor.
|
||||
Element epsilon(0.1f);
|
||||
Element nonzero_floor(std::numeric_limits<Element>::min());
|
||||
|
||||
if (check_relative_equality) {
|
||||
return cutlass::reference::host::TensorRelativelyEquals(
|
||||
lhs, rhs, Element(epsilon), Element(nonzero_floor));
|
||||
lhs, rhs, epsilon, nonzero_floor);
|
||||
}
|
||||
else {
|
||||
return cutlass::reference::host::TensorEquals(lhs, rhs);
|
||||
@ -933,6 +1009,7 @@ struct Testbed3xFusionOperation {
|
||||
|
||||
/// Compares computed reference with device reference and outputs to a file if incorrect
|
||||
bool compare_reference(cute::Shape<int,int,int,int> problem_shape_MNKL) {
|
||||
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
auto coord_0 = cutlass::make_Coord(0);
|
||||
|
||||
@ -947,17 +1024,24 @@ struct Testbed3xFusionOperation {
|
||||
}
|
||||
bool passed = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view());
|
||||
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledD) {
|
||||
abs_max_D.sync_host();
|
||||
passed &= equality_check(reference_abs_max_D.host_view(), abs_max_D.host_view());
|
||||
}
|
||||
|
||||
if constexpr (IsAuxEnabled) {
|
||||
if constexpr (IsDeBiasEnabled) {
|
||||
bias.sync_host();
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(bias.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_dbias.host_view()), 0);
|
||||
passed &= equality_check(reference_dbias.host_view(), bias.host_view());
|
||||
}
|
||||
|
||||
if constexpr (IsAuxOutEnabled) {
|
||||
tensor_Aux.sync_host();
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0);
|
||||
passed &= equality_check(reference_Aux.host_view(), tensor_Aux.host_view());
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledAux) {
|
||||
abs_max_Aux.sync_host();
|
||||
passed &= equality_check(reference_abs_max_Aux.host_view(), abs_max_Aux.host_view());
|
||||
}
|
||||
@ -990,7 +1074,7 @@ struct Testbed3xFusionOperation {
|
||||
}
|
||||
file << "\n\n";
|
||||
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledD) {
|
||||
file << "scale_d: " << float(scale_D.at(coord_0));
|
||||
file << "\nReference abs_max_D :";
|
||||
file << " " << float(reference_abs_max_D.at(coord_0));
|
||||
@ -998,15 +1082,16 @@ struct Testbed3xFusionOperation {
|
||||
file << "\nComputed abs_max_D :";
|
||||
file << " " << float(abs_max_D.at(coord_0));
|
||||
file << "\n\n";
|
||||
if constexpr (IsAuxEnabled) {
|
||||
file << "scale_aux: " << float(scale_Aux.at(coord_0));
|
||||
file << "\nReference abs_max_Aux :";
|
||||
file << " " << float(reference_abs_max_Aux.at(coord_0));
|
||||
}
|
||||
|
||||
file << "\nComputed abs_max_Aux :";
|
||||
file << " " << float(abs_max_Aux.at(coord_0));
|
||||
file << "\n\n";
|
||||
}
|
||||
if constexpr (IsAbsMaxEnabledAux) {
|
||||
file << "scale_aux: " << float(scale_Aux.at(coord_0));
|
||||
file << "\nReference abs_max_Aux :";
|
||||
file << " " << float(reference_abs_max_Aux.at(coord_0));
|
||||
|
||||
file << "\nComputed abs_max_Aux :";
|
||||
file << " " << float(abs_max_Aux.at(coord_0));
|
||||
file << "\n\n";
|
||||
}
|
||||
|
||||
file
|
||||
@ -1018,7 +1103,16 @@ struct Testbed3xFusionOperation {
|
||||
file << "\n\nBias = \n" << bias.host_view();
|
||||
}
|
||||
|
||||
if constexpr (IsAuxEnabled) {
|
||||
if constexpr (IsAuxInEnabled) {
|
||||
file << "\n\nAux Input = \n" << tensor_Aux.host_view();
|
||||
}
|
||||
|
||||
if constexpr (IsDeBiasEnabled) {
|
||||
file << "\n\nReference dBias = \n" << reference_dbias.host_view();
|
||||
file << "\n\nComputed dBias = \n" << bias.host_view();
|
||||
}
|
||||
|
||||
if constexpr (IsAuxOutEnabled) {
|
||||
file
|
||||
<< "\n\nReference Aux =\n" << reference_Aux.host_view()
|
||||
<< "\n\nComputed Aux =\n" << tensor_Aux.host_view();
|
||||
@ -1041,21 +1135,21 @@ struct Testbed3xFusionOperation {
|
||||
auto L = cute::get<3>(problem_shape_MNKL);
|
||||
auto coord_0 = cutlass::make_Coord(0);
|
||||
|
||||
auto A = cute::make_tensor(impl_.tensor_A.host_data(),
|
||||
auto A = cute::make_tensor(detail::make_iterator(impl_.tensor_A.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
|
||||
auto B = cute::make_tensor(impl_.tensor_B.host_data(),
|
||||
auto B = cute::make_tensor(detail::make_iterator(impl_.tensor_B.host_data()),
|
||||
cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b));
|
||||
auto C = cute::make_tensor(impl_.tensor_C.host_data(),
|
||||
auto C = cute::make_tensor(detail::make_iterator(impl_.tensor_C.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c));
|
||||
auto D = cute::make_tensor(impl_.reference_D.host_data(),
|
||||
auto D = cute::make_tensor(detail::make_iterator(impl_.reference_D.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d));
|
||||
auto Bias = cute::make_tensor(bias.host_data(),
|
||||
auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, cute::_1{})));
|
||||
auto Aux = cute::make_tensor(reference_Aux.host_data(),
|
||||
auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_Aux));
|
||||
auto Valpha = cute::make_tensor(alpha.host_data(),
|
||||
auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, cute::_1{})));
|
||||
auto Vbeta = cute::make_tensor(beta.host_data(),
|
||||
auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, cute::_1{})));
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
|
||||
@ -1086,20 +1180,24 @@ struct Testbed3xFusionOperation {
|
||||
epilogue_params.scale_d = scale_D.at(coord_0);
|
||||
}
|
||||
|
||||
if constexpr (IsBiasEnabled) {
|
||||
if constexpr (IsBiasEnabled or IsDeBiasEnabled) {
|
||||
epilogue_params.Bias = Bias;
|
||||
}
|
||||
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledD) {
|
||||
epilogue_params.abs_max_D = reference_abs_max_D.host_data();
|
||||
}
|
||||
|
||||
if constexpr (IsAuxEnabled) {
|
||||
if constexpr (IsAuxInEnabled) {
|
||||
epilogue_params.Aux = Aux;
|
||||
}
|
||||
|
||||
if constexpr (IsAuxOutEnabled) {
|
||||
epilogue_params.Aux = Aux;
|
||||
if constexpr (IsScaleFactorEnabled) {
|
||||
epilogue_params.scale_aux = scale_Aux.at(coord_0);
|
||||
}
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledAux) {
|
||||
epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data();
|
||||
}
|
||||
}
|
||||
@ -1121,6 +1219,8 @@ struct Testbed3xFusionOperation {
|
||||
ProblemShapeType problem_size,
|
||||
ElementScalar alpha_ = ElementScalar(1),
|
||||
ElementScalar beta_ = ElementScalar(0),
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
|
||||
detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{},
|
||||
detail::Splits splits = detail::Splits{},
|
||||
bool profiling = false,
|
||||
detail::Iterations iterations = detail::Iterations{})
|
||||
@ -1136,6 +1236,8 @@ struct Testbed3xFusionOperation {
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
cudaDeviceProp prop;
|
||||
|
||||
hw_info.device_id = 0;
|
||||
if (not profiling) {
|
||||
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
|
||||
@ -1146,6 +1248,8 @@ struct Testbed3xFusionOperation {
|
||||
hw_info.sm_count = impl_.sm_count;
|
||||
}
|
||||
|
||||
cudaGetDeviceProperties(&prop, hw_info.device_id);
|
||||
|
||||
/// Initializes data structures
|
||||
/// A/B/C/D Tensor
|
||||
initialize(problem_size, alpha_, beta_);
|
||||
@ -1172,7 +1276,7 @@ struct Testbed3xFusionOperation {
|
||||
hw_info,
|
||||
scheduler_args
|
||||
};
|
||||
|
||||
|
||||
auto coord_0 = cutlass::make_Coord(0);
|
||||
if constexpr (IsLegacy) {
|
||||
arguments.epilogue.thread = {
|
||||
@ -1186,7 +1290,7 @@ struct Testbed3xFusionOperation {
|
||||
}
|
||||
else {
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
|
||||
|
||||
fusion_args.alpha = alpha.at(coord_0);
|
||||
fusion_args.beta = beta.at(coord_0);
|
||||
fusion_args.alpha_ptr = alpha.device_data();
|
||||
@ -1207,6 +1311,10 @@ struct Testbed3xFusionOperation {
|
||||
fusion_args.bias_ptr = bias.device_data();
|
||||
}
|
||||
|
||||
if constexpr (IsDeBiasEnabled) {
|
||||
fusion_args.dbias_ptr = bias.device_data();
|
||||
}
|
||||
|
||||
// example of how to set kernel activation arguments
|
||||
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU_taylor<ElementCompute>>) {
|
||||
// see ActivationFunctor::Arguments in activation.h for definition
|
||||
@ -1214,18 +1322,23 @@ struct Testbed3xFusionOperation {
|
||||
fusion_args.activation.scale = ElementCompute(1);
|
||||
}
|
||||
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledD) {
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
if constexpr (IsAuxEnabled) {
|
||||
if constexpr (IsAuxInEnabled) {
|
||||
fusion_args.aux_ptr = tensor_Aux.device_data();
|
||||
fusion_args.dAux = stride_Aux;
|
||||
}
|
||||
|
||||
if constexpr (IsAuxOutEnabled) {
|
||||
fusion_args.aux_ptr = tensor_Aux.device_data();
|
||||
fusion_args.dAux = stride_Aux;
|
||||
if constexpr (IsScaleFactorEnabled) {
|
||||
fusion_args.scale_aux = scale_Aux.at(coord_0);
|
||||
fusion_args.scale_aux_ptr = scale_Aux.device_data();
|
||||
}
|
||||
if constexpr (IsAbsMaxEnabled) {
|
||||
if constexpr (IsAbsMaxEnabledAux) {
|
||||
fusion_args.amax_aux_ptr = abs_max_Aux.device_data();
|
||||
}
|
||||
}
|
||||
@ -1277,6 +1390,7 @@ struct Testbed3xFusionOperation {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
@ -1311,29 +1425,39 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) {
|
||||
problem_splits.push_back(Stages + 1);
|
||||
}
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
|
||||
std::vector<RasterOrderOptions> raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN};
|
||||
std::vector<int> max_swizzle_sizes = {1, 4};
|
||||
|
||||
bool passed = true;
|
||||
|
||||
for (int m : problem_size_m) {
|
||||
for (int n : problem_size_n) {
|
||||
for (int k : problem_size_k) {
|
||||
for (int splits : problem_splits) {
|
||||
ProblemShapeType problem_size;
|
||||
if constexpr (cute::rank(ProblemShapeType{}) == 4) {
|
||||
problem_size = ProblemShapeType{m, n, k, /* l */ 1};
|
||||
}
|
||||
else {
|
||||
problem_size = ProblemShapeType{m, n, k};
|
||||
}
|
||||
for (auto raster_order : raster_orders) {
|
||||
for (int max_swizzle_size : max_swizzle_sizes) {
|
||||
for (int splits : problem_splits) {
|
||||
ProblemShapeType problem_size;
|
||||
if constexpr (cute::rank(ProblemShapeType{}) == 4) {
|
||||
problem_size = ProblemShapeType{m, n, k, /* l */ 1};
|
||||
}
|
||||
else {
|
||||
problem_size = ProblemShapeType{m, n, k};
|
||||
}
|
||||
|
||||
passed = testbed.run(
|
||||
problem_size,
|
||||
cutlass::from_real<ElementScalar>(alpha),
|
||||
cutlass::from_real<ElementScalar>(beta),
|
||||
detail::Splits(splits)
|
||||
);
|
||||
passed = testbed.run(
|
||||
problem_size,
|
||||
cutlass::from_real<ElementScalar>(alpha),
|
||||
cutlass::from_real<ElementScalar>(beta),
|
||||
raster_order,
|
||||
detail::MaxSwizzleSize(max_swizzle_size),
|
||||
detail::Splits(splits)
|
||||
);
|
||||
|
||||
if (!passed) {
|
||||
return false;
|
||||
if (!passed) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -275,4 +275,4 @@ TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 16x128
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_f16t_s8n_f16t_mixed_input_tensor_op_f16, 128x128x
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -381,4 +381,4 @@ TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x16
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_s8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Host reference and operations for Sm90 EVT unit test
|
||||
\brief Host reference and operations for Sm90 EVT unit test
|
||||
*/
|
||||
#pragma once
|
||||
#include "gemm_testbed_3x_evt.hpp"
|
||||
@ -53,10 +53,10 @@ public:
|
||||
using ScalarAlpha = HostScalarBroadcast<Gemm, 1>;
|
||||
using AccFetchNode = HostAccumulator<Gemm>;
|
||||
using AuxLoadNode = HostAuxLoad<Gemm, false, ElementAux, LayoutAux>;
|
||||
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarAlpha, AccFetchNode, AuxLoadNode>;
|
||||
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarAlpha, AccFetchNode, AuxLoadNode>;
|
||||
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
|
||||
using CLoadNode = HostAuxLoad<Gemm, true>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
|
||||
using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>;
|
||||
};
|
||||
|
||||
@ -67,10 +67,10 @@ public:
|
||||
using ScalarAlpha = HostScalarBroadcast<Gemm, 1>;
|
||||
using AccFetchNode = HostAccumulator<Gemm>;
|
||||
using RowBroadcastNode = HostRowBroadcast<Gemm, ElementBias>;
|
||||
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarAlpha, AccFetchNode, RowBroadcastNode>;
|
||||
using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarAlpha, AccFetchNode, RowBroadcastNode>;
|
||||
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
|
||||
using CLoadNode = HostAuxLoad<Gemm, true>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>;
|
||||
using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>;
|
||||
};
|
||||
|
||||
@ -95,13 +95,13 @@ public:
|
||||
ScalarAlpha,
|
||||
AccFetchNode,
|
||||
AuxLoadNode,
|
||||
HostCompute<Gemm, cutlass::multiply_add>,
|
||||
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
|
||||
HostCompute<Gemm, cutlass::epilogue::thread::ReLu>,
|
||||
HostCompute<Gemm, cutlass::plus>
|
||||
>;
|
||||
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
|
||||
using CLoadNode = HostAuxLoad<Gemm, true>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, DAGNode>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, DAGNode>;
|
||||
using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>;
|
||||
};
|
||||
|
||||
@ -114,7 +114,7 @@ public:
|
||||
using EVTNode = HEVT<
|
||||
HostAuxStore<Gemm, false, cutlass::half_t, cutlass::layout::RowMajor>,
|
||||
HEVT<
|
||||
HostCompute<Gemm, cutlass::multiply_add>,
|
||||
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
|
||||
HostScalarBroadcast<Gemm, 2>,
|
||||
HostAccumulator<Gemm>,
|
||||
HostAuxLoad<Gemm, true>
|
||||
@ -133,7 +133,7 @@ public:
|
||||
EVTNode,
|
||||
HostColBroadcast<Gemm, cutlass::half_t>,
|
||||
HostCompute<Gemm, cutlass::plus>,
|
||||
HostCompute<Gemm, cutlass::maximum>
|
||||
HostCompute<Gemm, cutlass::maximum_with_default_nan_propagation>
|
||||
>
|
||||
>;
|
||||
};
|
||||
@ -147,13 +147,13 @@ public:
|
||||
using BinaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiplies>, ScalarAlpha, AccFetchNode>;
|
||||
using ScalarBeta = HostScalarBroadcast<Gemm, 1>;
|
||||
using CLoadNode = HostAuxLoad<Gemm, true>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, BinaryCompute0>;
|
||||
using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::homogeneous_multiply_add>, ScalarBeta, CLoadNode, BinaryCompute0>;
|
||||
using ReduceNode = HEVT<ReduceOp<Gemm, cutlass::plus, float>, TernaryCompute1>;
|
||||
using EVTModule = HEVT<HostAuxStore<Gemm, true>, ReduceNode>;
|
||||
};
|
||||
|
||||
// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
|
||||
// if D is fp8
|
||||
// if D is fp8
|
||||
// D = scale_d * activation(Z)
|
||||
// else
|
||||
// D = activation(Z)
|
||||
@ -167,11 +167,11 @@ public:
|
||||
HEVT<
|
||||
HostCompute<Gemm, ActivationFn>, // activation(Z)
|
||||
HEVT<
|
||||
HostCompute<Gemm, cutlass::multiply_add>,
|
||||
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
|
||||
HostScalarBroadcast<Gemm, 1, 2>, // scale_c * beta
|
||||
HostAuxLoad<Gemm, true>, // C
|
||||
HEVT<
|
||||
HostCompute<Gemm, cutlass::multiply_add>,
|
||||
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
|
||||
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
|
||||
HostAccumulator<Gemm>,
|
||||
HostColBroadcast<Gemm, ElementD>,
|
||||
@ -184,12 +184,12 @@ public:
|
||||
};
|
||||
|
||||
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
|
||||
// if D is fp8
|
||||
// if D is fp8
|
||||
// amax_d = max(abs(elements in activation(Z)))
|
||||
// D = scale_d * activation(Z)
|
||||
// else
|
||||
// D = activation(Z)
|
||||
// if Aux is fp8
|
||||
// if Aux is fp8
|
||||
// amax_aux = max(abs(elements in Z))
|
||||
// Aux = scale_aux * Z
|
||||
// else
|
||||
@ -204,11 +204,11 @@ public:
|
||||
HST<Gemm,
|
||||
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
|
||||
HEVT<
|
||||
HostCompute<Gemm, cutlass::multiply_add>,
|
||||
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
|
||||
HostScalarBroadcast<Gemm, 1, 2>, // scale_c * beta
|
||||
HostAuxLoad<Gemm, true>, // C
|
||||
HEVT<
|
||||
HostCompute<Gemm, cutlass::multiply_add>,
|
||||
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
|
||||
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
|
||||
HostAccumulator<Gemm>,
|
||||
HostColBroadcast<Gemm, ElementD>,
|
||||
@ -218,7 +218,7 @@ public:
|
||||
HEVT<
|
||||
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>,
|
||||
HEVT<
|
||||
HostScalarReduce<Gemm, amax, float>,
|
||||
HostScalarReduce<Gemm, amax, float>,
|
||||
HEVT<
|
||||
HostCompute<Gemm, ActivationFn>, //activation(Z) * scaled_d
|
||||
HostAccumulator<Gemm>, // Z
|
||||
@ -247,6 +247,13 @@ public:
|
||||
namespace cutlass::epilogue {
|
||||
namespace fusion {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
struct maximum_with_default_nan_propagation : maximum<T> {};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// D = alpha * acc + beta * C + AuxLoad
|
||||
template<
|
||||
@ -258,16 +265,16 @@ template<
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90LinCombAuxLoad =
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
Sm90AuxLoad<
|
||||
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxLoadDescriptor::Element,
|
||||
typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom,
|
||||
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxLoadDescriptor::Element,
|
||||
typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom,
|
||||
typename AuxLoadDescriptor::CopyOpS2R // aux load
|
||||
>
|
||||
>
|
||||
@ -286,7 +293,7 @@ template<
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90LinCombEVTDAG =
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + aux)
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + aux)
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90TopologicalVisitor<
|
||||
@ -302,13 +309,13 @@ using Sm90LinCombEVTDAG =
|
||||
Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
Sm90AuxLoad<
|
||||
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride,
|
||||
AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride,
|
||||
typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>,
|
||||
Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>,
|
||||
Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
|
||||
Sm90Compute<cutlass::epilogue::thread::ReLu, ElementCompute, ElementCompute, RoundStyle>,
|
||||
Sm90Compute<plus, ElementCompute, ElementCompute, RoundStyle>
|
||||
>
|
||||
>
|
||||
>;
|
||||
|
||||
|
||||
@ -336,10 +343,10 @@ using Sm90LinCombDAGEVT =
|
||||
>,
|
||||
Sm90EVT<
|
||||
Sm90AuxStore<
|
||||
AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride,
|
||||
typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>,
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>,
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
|
||||
Sm90ScalarBroadcast<ElementScalar>,
|
||||
Sm90AccFetch,
|
||||
Sm90SrcFetch
|
||||
@ -347,7 +354,7 @@ using Sm90LinCombDAGEVT =
|
||||
>,
|
||||
Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias>,
|
||||
Sm90Compute<plus, ElementCompute, ElementCompute, RoundStyle>,
|
||||
Sm90Compute<maximum, ElementOutput, ElementCompute, RoundStyle>
|
||||
Sm90Compute<detail::maximum_with_default_nan_propagation, ElementOutput, ElementCompute, RoundStyle>
|
||||
>;
|
||||
|
||||
|
||||
@ -362,18 +369,18 @@ template<
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90LinCombPerColumnBias =
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
Sm90RowBroadcast<
|
||||
ceil_div(
|
||||
EpilogueDescriptor::StagesC,
|
||||
EpilogueDescriptor::StagesC,
|
||||
size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{}))
|
||||
) + 1,
|
||||
typename EpilogueDescriptor::TileShape,
|
||||
) + 1,
|
||||
typename EpilogueDescriptor::TileShape,
|
||||
ElementBias
|
||||
>
|
||||
>
|
||||
@ -385,7 +392,7 @@ using Sm90LinCombPerColumnBias =
|
||||
template<
|
||||
template <class> class RegReduceFn,
|
||||
template <class> class GmemReduceFn,
|
||||
class ElementReduce,
|
||||
class ElementReduce,
|
||||
class CtaTileShapeMNK,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
@ -394,7 +401,7 @@ template<
|
||||
>
|
||||
using Sm90LinCombPerColumnReduce =
|
||||
Sm90EVT<Sm90RowReduction<RegReduceFn, GmemReduceFn, 0, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
|
||||
@ -410,7 +417,7 @@ using Sm90LinCombPerColumnReduce =
|
||||
template<
|
||||
template <class> class RegReduceFn,
|
||||
template <class> class GmemReduceFn,
|
||||
class ElementReduce,
|
||||
class ElementReduce,
|
||||
class CtaTileShapeMNK,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
@ -419,7 +426,7 @@ template<
|
||||
>
|
||||
using Sm90LinCombPerRowReduce =
|
||||
Sm90EVT<Sm90ColReduction<RegReduceFn, GmemReduceFn, 0, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
|
||||
@ -435,7 +442,7 @@ using Sm90LinCombPerRowReduce =
|
||||
template<
|
||||
template <class> class RegReduceFn,
|
||||
template <class> class GmemReduceFn,
|
||||
class ElementReduce,
|
||||
class ElementReduce,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementScalar = ElementCompute,
|
||||
@ -443,7 +450,7 @@ template<
|
||||
>
|
||||
using Sm90LinCombScalarReduce =
|
||||
Sm90EVT<Sm90ScalarReduction<RegReduceFn, GmemReduceFn, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
|
||||
Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch, // C
|
||||
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
|
||||
|
||||
@ -58,44 +58,7 @@ using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 8,
|
||||
cutlass::bfloat16_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -105,14 +68,14 @@ TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
cutlass::bfloat16_t, LayoutA, 4,
|
||||
cutlass::bfloat16_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 4,
|
||||
@ -132,9 +95,9 @@ TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
@ -142,14 +105,14 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
cutlass::bfloat16_t, LayoutA, 2,
|
||||
cutlass::bfloat16_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 2,
|
||||
@ -169,41 +132,4 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 8,
|
||||
cutlass::bfloat16_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@ -0,0 +1,172 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 8,
|
||||
cutlass::bfloat16_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 4,
|
||||
cutlass::bfloat16_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 4,
|
||||
cutlass::bfloat16_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 2,
|
||||
cutlass::bfloat16_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 2,
|
||||
cutlass::bfloat16_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,172 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 8,
|
||||
cutlass::bfloat16_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 4,
|
||||
cutlass::bfloat16_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 4,
|
||||
cutlass::bfloat16_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 2,
|
||||
cutlass::bfloat16_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 2,
|
||||
cutlass::bfloat16_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,172 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 8,
|
||||
cutlass::bfloat16_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::bfloat16_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 4,
|
||||
cutlass::bfloat16_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 4,
|
||||
cutlass::bfloat16_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::bfloat16_t, LayoutA, 2,
|
||||
cutlass::bfloat16_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::bfloat16_t, LayoutC, 2,
|
||||
cutlass::bfloat16_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,365 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// TT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// TN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// NT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// NN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -60,7 +60,7 @@ using namespace cute;
|
||||
///////////////////////////////////// TT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -72,7 +72,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelMultistage
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -95,7 +95,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -107,7 +107,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -131,7 +131,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
}
|
||||
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -143,7 +143,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -170,7 +170,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
///////////////////////////////////// TN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -182,7 +182,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelMultistage
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -207,7 +207,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -219,7 +219,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -244,7 +244,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -256,7 +256,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -283,7 +283,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
///////////////////////////////////// NT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -295,7 +295,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelMultistage
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -320,7 +320,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -332,7 +332,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -357,7 +357,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -369,7 +369,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -396,7 +396,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
///////////////////////////////////// NN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -408,7 +408,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelMultistage
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -433,7 +433,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -445,7 +445,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -470,7 +470,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -482,7 +482,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) {
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -0,0 +1,510 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// TT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// TN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// NT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// NN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,510 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// TT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// TN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// NT //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////// NN //////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 4,
|
||||
cutlass::half_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::half_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) {
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 2,
|
||||
cutlass::half_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::half_t, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -469,7 +469,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8_ReLU_VoidC) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8_ReLU_VoidC_U1Aux) {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
@ -478,8 +478,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
// ReLU with uint1b_t aux will compute dReLU/dZ as the aux output, i.e. Aux(i) = (Z(i) >= 0) ? 1 : 0
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux<
|
||||
LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, int8_t>;
|
||||
LayoutC, cutlass::epilogue::thread::ReLU, cutlass::half_t, float, cutlass::uint1b_t, int8_t>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
@ -514,4 +515,94 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_dReLU_dBias_VoidC) {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias<
|
||||
LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
void, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_dGELU_VoidC) {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltAct<
|
||||
LayoutC, cutlass::epilogue::thread::dGELU, cutlass::half_t, float, cutlass::half_t>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
void, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>(1.0, 0.0, /*check_relative_equality=*/true);
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@ -460,4 +460,49 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_pingpong_epilogue, 128x128x64_2x2x1_dReLU_dBias_VoidC) {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias<
|
||||
LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
void, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@ -0,0 +1,209 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x192x64_1x1x1) {
|
||||
using ElementA = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::half_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementC = float;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using TileShape_MNK = Shape<_128,_192,_64>;
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
|
||||
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
|
||||
|
||||
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
|
||||
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
|
||||
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
|
||||
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
|
||||
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
cutlass::gemm::TagToStrideA_t<LayoutA>,
|
||||
ElementB,
|
||||
cutlass::gemm::TagToStrideB_t<LayoutB>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
void,
|
||||
cute::identity
|
||||
>;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
EpilogueOp
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x192x64_2x1x1) {
|
||||
using ElementA = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::half_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementC = float;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using TileShape_MNK = Shape<_128,_192,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_1,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
|
||||
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
|
||||
|
||||
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
|
||||
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
|
||||
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
|
||||
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
|
||||
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
cutlass::gemm::TagToStrideA_t<LayoutA>,
|
||||
ElementB,
|
||||
cutlass::gemm::TagToStrideB_t<LayoutB>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
void,
|
||||
cute::identity
|
||||
>;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
EpilogueOp
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,209 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_1x1x1) {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementC = float;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using TileShape_MNK = Shape<_128,_128,_128>;
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
|
||||
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
|
||||
|
||||
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
|
||||
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
|
||||
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
|
||||
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
|
||||
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
cutlass::gemm::TagToStrideA_t<LayoutA>,
|
||||
ElementB,
|
||||
cutlass::gemm::TagToStrideB_t<LayoutB>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
void,
|
||||
cute::identity
|
||||
>;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
EpilogueOp
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_2x1x1) {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementC = float;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using TileShape_MNK = Shape<_128,_128,_128>;
|
||||
using ClusterShape_MNK = Shape<_2,_1,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct<
|
||||
cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>;
|
||||
|
||||
using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
|
||||
using AtomLayoutMNK = Layout<Shape<_2,_1,_1>>;
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{}));
|
||||
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B<LayoutB>();
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector<GmmaMajorB, ElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>());
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOp::SharedStorage)>;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override<
|
||||
cutlass::gemm::collective::detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized<
|
||||
PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
cutlass::gemm::TagToStrideA_t<LayoutA>,
|
||||
ElementB,
|
||||
cutlass::gemm::TagToStrideB_t<LayoutB>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
cute::Copy_Atom<cute::SM75_U32x4_LDSM_N,ElementA>,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
void,
|
||||
cute::identity
|
||||
>;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
EpilogueOp
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise<Gemm>());
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -58,7 +58,7 @@ using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) {
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -68,14 +68,14 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) {
|
||||
int8_t, LayoutA, 8,
|
||||
int8_t, LayoutB, 8,
|
||||
int32_t,
|
||||
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 8,
|
||||
@ -93,42 +93,7 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) {
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelMultistage
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 8,
|
||||
int8_t, LayoutC, 8,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) {
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -138,14 +103,14 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) {
|
||||
int8_t, LayoutA, 4,
|
||||
int8_t, LayoutB, 4,
|
||||
int32_t,
|
||||
Shape<_128,_64,_128>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_128>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 4,
|
||||
|
||||
@ -0,0 +1,168 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 16,
|
||||
int8_t, LayoutC, 16,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 8,
|
||||
int8_t, LayoutB, 8,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 8,
|
||||
int8_t, LayoutC, 8,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 4,
|
||||
int8_t, LayoutB, 4,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 4,
|
||||
int8_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,168 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 16,
|
||||
int8_t, LayoutC, 16,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 8,
|
||||
int8_t, LayoutB, 8,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 8,
|
||||
int8_t, LayoutC, 8,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 4,
|
||||
int8_t, LayoutB, 4,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 4,
|
||||
int8_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,168 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 16,
|
||||
int8_t, LayoutC, 16,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 8,
|
||||
int8_t, LayoutB, 8,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 8,
|
||||
int8_t, LayoutC, 8,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 4,
|
||||
int8_t, LayoutB, 4,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 4,
|
||||
int8_t, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -62,10 +62,10 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape
|
||||
Scheduler scheduler{params};
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
while (work_tile_info.is_valid_tile) {
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Increment counters to indicate coverage
|
||||
auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
|
||||
auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx;
|
||||
auto offset = tile_idx * params.divmod_tiles_per_output_tile_.divisor + work_tile_info.K_idx;
|
||||
for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
|
||||
// Use atomicAdd because the visit counters are shared by multiple thread blocks.
|
||||
// While having more than one block increment the same counter indicates failure,
|
||||
@ -108,7 +108,7 @@ test_scheduler(
|
||||
|
||||
// Allocate counters indicating the number of times each k iteration of each output tile has been visited
|
||||
auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
|
||||
auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_;
|
||||
auto total_counters = blk_m * blk_n * blk_l * params.divmod_tiles_per_output_tile_.divisor;
|
||||
cutlass::DeviceAllocation<int> visit_counters(total_counters);
|
||||
|
||||
// Initialize counters to zero
|
||||
@ -181,8 +181,6 @@ test_scheduler(
|
||||
|
||||
for (size_t i = 0; i < host_visit_counts.size(); ++i) {
|
||||
if (host_visit_counts[i] != 1) {
|
||||
// for (int count : host_visit_counts) {
|
||||
// if (count != 1) {
|
||||
std::cout << "Failed with problem size "
|
||||
<< size<0>(problem_shape_mnkl) << "x"
|
||||
<< size<1>(problem_shape_mnkl) << "x"
|
||||
@ -191,11 +189,12 @@ test_scheduler(
|
||||
<< " and grid size " << grid.x << "x"
|
||||
<< grid.y << "x" << grid.z
|
||||
<< " splits=" << params.splits_
|
||||
<< " k_iter=" << params.k_tiles_per_output_tile_
|
||||
<< " k_iter=" << params.divmod_tiles_per_output_tile_.divisor
|
||||
<< " big_units=" << params.big_units_
|
||||
<< " sk_tiles=" << params.sk_tiles_
|
||||
<< " sk_units=" << params.sk_units_
|
||||
<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl;
|
||||
<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_
|
||||
<< " units_per_problem=" << params.units_per_problem_ << std::endl;
|
||||
std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -57,42 +57,7 @@ using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32, 64x128x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
tfloat32_t, LayoutA, 4,
|
||||
tfloat32_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelMultistage
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) {
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@ -102,7 +67,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) {
|
||||
cutlass::tfloat32_t, LayoutA, 2,
|
||||
cutlass::tfloat32_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_64,_64,_32>, Shape<_1,_1,_1>,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
@ -0,0 +1,167 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
tfloat32_t, LayoutA, 4,
|
||||
tfloat32_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 2,
|
||||
cutlass::tfloat32_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 2,
|
||||
float, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 1,
|
||||
cutlass::tfloat32_t, LayoutB, 1,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 1,
|
||||
float, LayoutC, 1,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,167 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
tfloat32_t, LayoutA, 4,
|
||||
tfloat32_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 2,
|
||||
cutlass::tfloat32_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 2,
|
||||
float, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 1,
|
||||
cutlass::tfloat32_t, LayoutB, 1,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 1,
|
||||
float, LayoutC, 1,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -0,0 +1,167 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
tfloat32_t, LayoutA, 4,
|
||||
tfloat32_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 2,
|
||||
cutlass::tfloat32_t, LayoutB, 2,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 2,
|
||||
float, LayoutC, 2,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 1,
|
||||
cutlass::tfloat32_t, LayoutB, 1,
|
||||
float,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_64,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 1,
|
||||
float, LayoutC, 1,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -99,7 +99,7 @@ TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) {
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 1,
|
||||
cutlass::tfloat32_t, LayoutA, 4,
|
||||
cutlass::tfloat32_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
|
||||
@ -136,8 +136,8 @@ TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 1,
|
||||
cutlass::tfloat32_t, LayoutB, 1,
|
||||
cutlass::tfloat32_t, LayoutA, 4,
|
||||
cutlass::tfloat32_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
@ -149,8 +149,8 @@ TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
|
||||
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 1,
|
||||
float, LayoutC, 1,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
@ -174,7 +174,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::tfloat32_t, LayoutA, 4,
|
||||
cutlass::tfloat32_t, LayoutB, 1,
|
||||
cutlass::tfloat32_t, LayoutB, 4,
|
||||
float,
|
||||
Shape<_64,_128,_32>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
@ -188,7 +188,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) {
|
||||
float, float,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
cutlass::gemm::EpilogueTransposed
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
|
||||
@ -337,6 +337,8 @@ TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_un_tensor_op_f32_align1_align4, 128x
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// This test fails on Ada when running with 11.8
|
||||
#if ((__CUDACC_VER_MAJOR__ != 11) || (__CUDACC_VER_MINOR__ != 8) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890)))
|
||||
TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) {
|
||||
|
||||
using ElementOutput = float;
|
||||
@ -374,6 +376,7 @@ TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());
|
||||
}
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -55,6 +55,7 @@
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= F16 * I8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -98,6 +99,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= I8 * F16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -118,7 +120,6 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_
|
||||
.run();
|
||||
}
|
||||
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -142,6 +143,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= F16 * U8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -185,6 +187,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= U8 * F16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -225,10 +228,10 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_
|
||||
.run();
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= B16 * U8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -252,6 +255,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= U8 * BF16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -273,8 +277,9 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_1
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= B16 * I8 + F32 (Upcast on Operand B)
|
||||
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -296,8 +301,9 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_1
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
|
||||
/// F32 <= B16 * I8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
@ -318,4 +324,4 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1
|
||||
.run();
|
||||
}
|
||||
|
||||
#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
@ -44,7 +44,7 @@ void rmsnorm_host(cutlass::MatrixCoord tensor_size,
|
||||
cutlass::TensorRef<ElementType, Layout> output,
|
||||
cutlass::TensorRef<ElementType, Layout> input,
|
||||
cutlass::TensorRef<ElementType, Layout> weight,
|
||||
float epsilon) {
|
||||
float epsilon) {
|
||||
const int M = tensor_size.row();
|
||||
const int N = tensor_size.column();
|
||||
|
||||
@ -94,7 +94,7 @@ void run_test(int M, int N) {
|
||||
|
||||
rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5);
|
||||
cutlass::rmsnorm({M, N}, output.device_ref(),
|
||||
input.device_ref(), weight.device_ref(), NULL, (float)1e-5);
|
||||
input.device_ref(), weight.device_ref(), NULL, (float)1e-5L);
|
||||
|
||||
output.sync_host();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user