CUTLASS 3.3.0 (#1167)

* Release 3.3.0

Adds support for mixed precision GEMMs on Hopper and Ampere
Adds support for < 16B aligned GEMMs on Hopper
Enhancements to EVT
Enhancements to Python interface
Enhancements to sub-byte type handling in CuTe
Several other bug fixes and performance improvements.

* minor doc update
This commit is contained in:
Pradeep Ramani
2023-11-02 08:09:05 -07:00
committed by GitHub
parent 922fb5108b
commit c008b4aea8
263 changed files with 16214 additions and 5008 deletions

View File

@ -36,8 +36,9 @@ Utilities for defining Conv2D problem sizes for testing.
This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h
"""
from cutlass_library import ConvMode
import cutlass
from cutlass import ConvMode
from cutlass.shape import Conv2DProblemSize

View File

@ -34,10 +34,11 @@
Utility functions for Conv2d tests.
"""
from cutlass_library import SubstituteTemplate
import torch
import cutlass
from cutlass import (
from cutlass_library import (
ConvKind,
ConvMode,
DataType,
@ -50,7 +51,6 @@ from cutlass import (
ShortLayoutTypeNames,
SplitKMode,
)
from cutlass.backend.utils.software import SubstituteTemplate
from cutlass.shape import Conv2DProblemSize
from cutlass.utils.datatypes import numpy_type, torch_type
@ -301,17 +301,19 @@ class Conv2dLauncherFrontend:
tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B)
tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C)
tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last)
self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
stride=(ps.stride_h, ps.stride_w),
padding=(ps.pad_h, ps.pad_w),
dilation=(ps.dilation_h, ps.dilation_w),
alpha=alpha, beta=beta,
split_k=(split_k_mode, split_k_slices))
args.sync()
tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation)
torch.cuda.synchronize()
passed = torch.equal(tensor_D, tensor_D_ref)
passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06)
return passed
@ -378,7 +380,8 @@ def add_test(
conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch")
for ps in problem_sizes:
if not validate_problem_size(ps, conv_kind, split_k_slices): continue
if not validate_problem_size(ps, conv_kind, split_k_slices):
continue
self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0))

View File

@ -38,9 +38,11 @@ import random
import tempfile
import unittest
from cutlass_library import ConvMode
import cutlass
if cutlass.utils.datatypes.torch_available:
if cutlass.utils.datatypes.is_torch_available():
import torch
@ -88,7 +90,7 @@ def _generate_problems(dtype, num):
def _generate_conv2d_problem(conv_kind, dtype, ps):
"""
Utility function to generate conv2d inputs
:param conv_kind: kind of convolution
:type conv_kind: str
:param dtype: data type of tensors
@ -114,7 +116,7 @@ def _generate_conv2d_problem(conv_kind, dtype, ps):
return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes]
@unittest.skipIf(not cutlass.utils.datatypes.torch_available, 'PyTorch must be available to run PyTorch extension tests')
@unittest.skipIf(not cutlass.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests')
class PyTorchExtensionTest(unittest.TestCase):
def test_gemm(self):
@ -183,18 +185,18 @@ class PyTorchExtensionTest(unittest.TestCase):
Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)]
Ds = mod.run(As, Bs, Cs, alpha, beta)
check_all(Ds, Ds_ref)
def test_conv2d_fprop(self):
torch.manual_seed(2023)
dtype = torch.float16
plan = cutlass.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32)
plan.activation = "relu"
op = plan.construct()
with tempfile.TemporaryDirectory() as tmpdir:
mod = cutlass.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
problem_size = cutlass.shape.Conv2DProblemSize(
1, 4, 4, 16,
8, 3, 3, 16,
@ -202,50 +204,50 @@ class PyTorchExtensionTest(unittest.TestCase):
3, 3,
1, 1
)
A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size)
stride = (problem_size.stride_h, problem_size.stride_w)
padding = (problem_size.pad_h, problem_size.pad_w)
alpha = 1.0
beta = 0.5
D_ref = alpha * torch.ops.aten.conv2d(
A, B, stride=stride, padding=padding
) + beta * C
D_ref = torch.nn.functional.relu(D_ref)
D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta)
assert torch.allclose(D, D_ref)
assert torch.allclose(D, D_ref)
# Test serial split-K
D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
assert torch.allclose(D, D_serial_split_k)
# Test parallel split-K
D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
assert torch.allclose(D, D_parallel_split_k)
def test_conv2d_dgrad(self):
torch.manual_seed(2023)
dtype = torch.float16
plan = cutlass.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32)
op = plan.construct()
with tempfile.TemporaryDirectory() as tmpdir:
mod = cutlass.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
problem_size = cutlass.shape.Conv2DProblemSize(
1, 4, 4, 16,
8, 3, 3, 16,
0, 0,
3, 3,
1, 1,
cutlass.ConvMode.CrossCorrelation,
ConvMode.CrossCorrelation,
1, 1
)
A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size)
stride = (problem_size.stride_h, problem_size.stride_w)
padding = (problem_size.pad_h, problem_size.pad_w)
@ -254,32 +256,32 @@ class PyTorchExtensionTest(unittest.TestCase):
beta = 0.5
input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W)
D_ref = alpha * torch.nn.grad.conv2d_input(
input_size, B, A,
input_size, B, A,
stride=stride, padding=padding
) + beta * C
D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, )
assert torch.allclose(D, D_ref)
assert torch.allclose(D, D_ref)
def test_conv2d_wgrad(self):
torch.manual_seed(2023)
dtype = torch.float16
plan = cutlass.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32)
op = plan.construct()
with tempfile.TemporaryDirectory() as tmpdir:
mod = cutlass.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
problem_size = cutlass.shape.Conv2DProblemSize(
1, 4, 4, 16,
8, 3, 3, 16,
0, 0,
3, 3,
1, 1,
cutlass.ConvMode.CrossCorrelation,
ConvMode.CrossCorrelation,
1, 1
)
A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size)
stride = (problem_size.stride_h, problem_size.stride_w)
padding = (problem_size.pad_h, problem_size.pad_w)
@ -288,17 +290,17 @@ class PyTorchExtensionTest(unittest.TestCase):
beta = 0.5
weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S)
D_ref = alpha * torch.nn.grad.conv2d_weight(
B, weight_size, A,
B, weight_size, A,
stride=stride, padding=padding
) + beta * C
D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta)
assert torch.allclose(D, D_ref)
assert torch.allclose(D, D_ref)
# Test serial split-K
D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
assert torch.allclose(D, D_serial_split_k)
# Test parallel split-K
D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
assert torch.allclose(D, D_parallel_split_k)

View File

@ -40,9 +40,9 @@ import unittest
import cutlass
from cutlass import Tensor
import cutlass.backend.evt
from cutlass.profiler import CUDAEventProfiler
from cutlass.shape import GemmCoord
from cutlass.utils.datatypes import torch_type
from cutlass.utils.profiler import CUDAEventProfiler
class EVTReferenceModule:

View File

@ -43,7 +43,7 @@ import cutlass
from cutlass.backend.utils.device import device_cc
import torch
from utils import LayoutCombination, add_test_gemm
from utils import LayoutCombination
cutlass.set_log_level(logging.WARNING)

View File

@ -67,58 +67,58 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f16, cc=c
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16,
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16,
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
if __name__ == '__main__':

View File

@ -135,6 +135,10 @@ add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass.DataType.f16
add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8])
add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8])
# Tests with void-C kernels
add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None,
cluster_shape=[2, 1, 1], element_C=cutlass.DataType.void)
if __name__ == '__main__':
unittest.main()

View File

@ -68,31 +68,31 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f32, cc=c
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32,
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32,
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)

View File

@ -68,30 +68,30 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f64, cc=c
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64,
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)

View File

@ -0,0 +1,72 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Low-level functionality tests for GEMM with mixed operands on SM80
"""
from functools import partial
import logging
import unittest
import cutlass
from cutlass.backend.utils.device import device_cc
from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 80
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
class GemmMixedSm80(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
"""
pass
add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1],
opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64],
warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass.DataType.f32)
# Test with upcast on A
add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT)
add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN)
# Test with upcast on B
add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT)
add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN)
if __name__ == '__main__':
unittest.main()

View File

@ -68,30 +68,30 @@ add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.s8, cc=cc
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3)
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32,
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32,
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8,
add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)

View File

@ -37,7 +37,7 @@ import subprocess
import torch
from cutlass import (
from cutlass_library import (
DataType,
DataTypeSize,
GemmUniversalMode,
@ -49,7 +49,6 @@ from cutlass import (
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.shape import GemmCoord, MatrixCoord
from cutlass.utils.datatypes import torch_type
@ -65,16 +64,6 @@ class GemmUniversalLauncher:
compiler_mode= "nvcc",
**kwargs,
) -> None:
# Create the reduction kernel, if needed
self.reduction_operation: ReductionOperation = ReductionOperation(
shape=MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
self.math_operation = operation.tile_description.math_instruction.math_operation
self.verification = verification
@ -88,19 +77,26 @@ class GemmUniversalLauncher:
op_list = [operation]
if operation.arch < 90:
# Split K via Python is currently only supported for pre-SM90 kernels
self.reduction_operation: ReductionOperation = ReductionOperation(
shape=MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
op_list.append(self.reduction_operation)
compiler.add_module(op_list, bypass_cache=False)
self.operation = operation
self.dtype_A = torch_type(operation.A.element)
self.dtype_B = torch_type(operation.B.element)
self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element)
self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element)
self.dtype_C = torch_type(operation.C.element)
self.dtype_D = torch_type(operation.C.element)
self.dtype_D = torch_type(operation.epilogue_functor.element_output)
accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator]
element_size = DataTypeSize[operation.A.element]
element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])
if element_size == 1:
self.rand_max = 1
@ -154,7 +150,18 @@ class GemmUniversalLauncher:
def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
# If any tensor is on CPU, place all tensors on CPU unless only
# tensor C is on CPU
devices = [x.device.type for x in [tensor_A, tensor_B, tensor_C]]
# Handle mixed-input cases by casting to the larger data type and overriding
# to whatever the data type of the larger type is
if self.dtype_A != self.dtype_B:
if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]:
tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device)
else:
tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device)
devices = [x.device.type for x in [tensor_A, tensor_B]]
if tensor_C is not None:
devices.append(tensor_C.device.type)
if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]:
device = torch.device("cpu")
else:
@ -162,14 +169,17 @@ class GemmUniversalLauncher:
tensor_A = tensor_A.to(device)
tensor_B = tensor_B.to(device)
tensor_C = tensor_C.to(device)
if tensor_C is not None:
tensor_C = tensor_C.to(device)
dtype = torch_type(self.compute_type)
alpha_torch = torch.tensor([alpha], device=device).to(dtype)
beta_torch = torch.tensor([beta], device=device).to(dtype)
tmp = tensor_A @ tensor_B
tensor_D_ref = (alpha_torch * tmp) + (tensor_C * beta_torch)
tensor_D_ref = (alpha_torch * tmp)
if tensor_C is not None:
tensor_D_ref += (tensor_C * beta_torch)
return tensor_D_ref.to(self.dtype_D)
def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
@ -199,12 +209,22 @@ class GemmUniversalLauncher:
self.dtype_B,
self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout),
)
tensor_C, tensor_C_ref = self.uniform_init(
if self.dtype_C is not None:
tensor_C, tensor_C_ref = self.uniform_init(
(true_batch_count, problem_size.m, problem_size.n),
self.dtype_C,
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
)
else:
tensor_C = None
tensor_C_ref = None
tensor_D, _ = self.uniform_init(
(true_batch_count, problem_size.m, problem_size.n),
self.dtype_C,
self.dtype_D,
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
)
tensor_D = torch.zeros_like(tensor_C)
tensor_D = torch.zeros_like(tensor_D)
if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
alpha = int(alpha)
@ -248,6 +268,10 @@ class GemmUniversalLauncher:
if self.verification:
if mode == GemmUniversalMode.GemmSplitKParallel:
reduction_arguments.sync()
# Free memory allocated by args because we are not
# calling `arguments.sync()` in this case (which will free memory)
arguments.free()
else:
arguments.sync()
tensor_D_ref = self.reference(
@ -274,9 +298,6 @@ class GemmUniversalLauncher:
if mode == GemmUniversalMode.GemmSplitKParallel:
del reduction_arguments
cur_size = get_allocated_size()
assert cur_size == 0, f"{cur_size} B of memory were not released after this run"
return passed

View File

@ -30,9 +30,10 @@
#
#################################################################################################
import cutlass
from cutlass_library import SubstituteTemplate
from cutlass import (
import cutlass
from cutlass_library import (
DataTypeNames,
EpilogueScheduleSuffixes,
KernelScheduleSuffixes,
@ -42,7 +43,6 @@ from cutlass import (
ShortLayoutTypeNames
)
from cutlass.backend import library
from cutlass.backend.utils.software import SubstituteTemplate
from gemm_testbed import test_all_gemm
@ -82,6 +82,7 @@ def get_name(
stages,
element_a,
element_b,
element_c,
arch,
opclass,
kernel_schedule=None,
@ -102,6 +103,7 @@ def get_name(
:type stages: int
:param element_a: data type of operand A
:param element_b: data type of operand B
:param element_c: data type of operand C
:param arch: compute capability of kernel being generated
:type arch: int
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
@ -122,7 +124,7 @@ def get_name(
"arch": str(arch),
"eA": DataTypeNames[element_a],
"eB": DataTypeNames[element_b],
"eC": DataTypeNames[element_output],
"eC": DataTypeNames[element_c],
"lA": ShortLayoutTypeNames[layouts[0]],
"lB": ShortLayoutTypeNames[layouts[1]],
"lC": ShortLayoutTypeNames[layouts[2]],
@ -161,7 +163,10 @@ def add_test_gemm(
swizzle=None,
kernel_schedule=None,
epilogue_schedule=None,
compilation_modes=['nvcc', 'nvrtc']):
compilation_modes=['nvcc', 'nvrtc'],
element_A=None,
element_B=None,
element_C=None):
"""
Create test-running functions with the given specification and set it as a method of ``cls``.
@ -195,22 +200,38 @@ def add_test_gemm(
:param epilogue_schedule: epilogue schedule to use
:type epilogue_schedule: cutlass.EpilogueScheduleType
:param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc')
:type compilation_modes: list
:type compilation_modes: list
:param element_A: data type of operand A. If set, overrides ``element``
:type element_A: cutlass.DataType
:param element_B: data type of operand B. If set, overrides ``element``
:type element_B: cutlass.DataType
:param element_C: data type of operand C. If set, overrides ``element``
:type element_C: cutlass.DataType
"""
if element_A is None:
element_A = element
if element_B is None:
element_B = element
if element_C is None:
element_C = element
if element_output is None:
element_output = element
if element_accumulator is None:
element_accumulator = element
for compilation_mode in compilation_modes:
def run(self):
"""
Dynamically-generated function that constructs a GEMM operation and verifies it against
multiple test cases.
"""
element_A = element
element_B = element
layout_A, layout_B, layout_C = layouts
alignment_A, alignment_B, alignment_C = alignments
plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B,
element_C=element_output, element_D=element_output,
element_C=element_C, element_D=element_output,
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
element_accumulator=element_accumulator,
kernel_cc=cc)
@ -233,7 +254,7 @@ def add_test_gemm(
name = get_name(
layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass,
stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass,
kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')
setattr(cls, name, run)

View File

@ -0,0 +1,57 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Tests for a successful installation of the CUTLASS Python interface
"""
import os
import unittest
import cutlass
import cutlass_library
class InstallationTest(unittest.TestCase):
    """
    Checks that the CUTLASS Python interface installation shipped its source tree.
    """

    def test_cutlass_source_paths(self):
        """
        Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages
        """
        # A representative header that must exist under both package source roots.
        relative_header = 'include/cutlass/cutlass.h'
        # Check the cutlass_library source path first, then the cutlass package path,
        # matching the order of the original per-path assertions.
        for base_path in (cutlass_library.source_path, cutlass.CUTLASS_PATH):
            candidate = os.path.join(base_path, relative_header)
            assert os.path.isfile(candidate), f"Unable to locate file {candidate}. Installation has not succeeded."
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()

View File

@ -50,7 +50,7 @@ class Conv2dEquivalence:
"""
def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator,
alignment_A, alignment_B, alignment_C):
self.element_A = element_A
self.element_B = element_B
self.element_C = element_C
@ -59,21 +59,21 @@ class Conv2dEquivalence:
self.alignment_A = alignment_A
self.alignment_B = alignment_B
self.alignment_C = alignment_C
self.conv_kind = conv_kind
self.plan = cutlass.op.Conv2d(
kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C,
element_D=element_D, element_accumulator=element_accumulator)
self.op = self.plan.construct(
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_C=self.alignment_C)
def _plans_equal(self, other_plan) -> bool:
"""
Compares whether two plans are equal
:param other_plan: plan to compare against the default Conv2d
:type other_plan: cutlass.op.Conv2d
@ -81,9 +81,9 @@ class Conv2dEquivalence:
:rtype: bool
"""
other_op = other_plan.construct(
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
alignment_C=self.alignment_C)
return self.op.rt_module.emit() == other_op.rt_module.emit()
def generic_test(self):
@ -91,16 +91,16 @@ class Conv2dEquivalence:
Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types
and layouts for constructing the Conv2d interface
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
# Test when specifying all parameters
plan_other = cutlass.op.Conv2d(
kind=self.conv_kind,
element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
element_D=self.element_D, element_accumulator=self.element_accumulator)
assert self._plans_equal(plan_other)
# Test when specifying all parameters but A
plan_other = cutlass.op.Conv2d(
kind=self.conv_kind,
@ -108,7 +108,7 @@ class Conv2dEquivalence:
element_D=self.element_D, element_accumulator=self.element_accumulator,
element=self.element_A)
assert self._plans_equal(plan_other)
# Test when specifying all parameters but A and B as tensors using generic element and output
plan_other = cutlass.op.Conv2d(
kind=self.conv_kind,
@ -116,7 +116,7 @@ class Conv2dEquivalence:
element_D=self.element_D, element_accumulator=self.element_accumulator,
element=self.element_A)
assert self._plans_equal(plan_other)
# Test without explicit accumulator. Only run if the type of C and the accumulator are equal
if self.element_C == self.element_accumulator:
plan_other = cutlass.op.Conv2d(
@ -125,18 +125,18 @@ class Conv2dEquivalence:
element_D=self.element_D,
element=self.element_A)
assert self._plans_equal(plan_other)
# Test with only the generic types. Only run if the types of A, B, C, and D are the same
if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
and self.element_A == self.element_accumulator):
plan_other = cutlass.op.Conv2d(kind=self.conv_kind, element=self.element_A)
assert self._plans_equal(plan_other)
def numpy_test(self):
"""
Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
import numpy as np
@ -145,7 +145,7 @@ class Conv2dEquivalence:
type_C = datatypes.numpy_type(self.element_C)
type_D = datatypes.numpy_type(self.element_D)
type_accum = datatypes.numpy_type(self.element_accumulator)
size = (2, 2)
A = np.zeros(size, dtype=type_A)
B = np.zeros(size, dtype=type_B)
@ -153,49 +153,49 @@ class Conv2dEquivalence:
D = np.zeros(size, dtype=type_D)
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
def torch_test(self):
"""
Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend
"""
if not datatypes.torch_available:
if not datatypes.is_torch_available():
return
import torch
type_A = datatypes.torch_type(self.element_A)
type_B = datatypes.torch_type(self.element_B)
type_C = datatypes.torch_type(self.element_C)
type_D = datatypes.torch_type(self.element_D)
type_accum = datatypes.torch_type(self.element_accumulator)
size = (2, 2)
A = torch.empty(size, dtype=type_A)
B = torch.empty(size, dtype=type_B)
C = torch.empty(size, dtype=type_C)
D = torch.empty(size, dtype=type_D)
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D):
# Test when specifying all parameters via tensors
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum)
assert self._plans_equal(plan_np)
# Test when specifying all parameters but A as tensors
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A)
assert self._plans_equal(plan_np)
# Test when specifying all parameters but A and B as tensors and using generic element and output
if type_A == type_B:
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A)
assert self._plans_equal(plan_np)
# Test without explicit accumulator. Only run if the type of C and the accumulator are the same.
if type_C == type_accum:
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D)
assert self._plans_equal(plan_np)
# Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum):
plan_np = cutlass.op.Conv2d(kind=self.conv_kind, element=type_A)
@ -223,20 +223,20 @@ type2alignment = {
}
def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator):
test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}"
def run(self):
conv2d_eq = Conv2dEquivalence(
conv_kind=conv_kind,
conv_kind=conv_kind,
element_A=element_A, element_B=element_B,
element_C=element_C, element_D=element_D,
element_accumulator=element_accumulator,
element_accumulator=element_accumulator,
alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B],
alignment_C=type2alignment[element_C]
)
conv2d_eq.test_all()
setattr(ConvEquivalenceTest, test_name, run)
for conv_kind in ["fprop", "wgrad", "dgrad"]:
@ -255,25 +255,25 @@ class Conv2dErrorTests(unittest.TestCase):
"""
Tests various error scenarios that arise with the high-level Gemm interface
"""
def test_alignment(self):
"""
Tests case in which the alignment specified is unsupported
"""
plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16)
with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'):
op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3)
def test_invalid_tile_description(self):
"""
Tests scenarios in which an invalid tile description is provided for a given CC
"""
plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16)
td = plan.tile_descriptions()[0]
td.threadblock_shape=[17, 32, 5]
plan.tile_description = td
with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'):
plan.compile()

View File

@ -93,13 +93,16 @@ class EVTErrorTests(unittest.TestCase):
"""
Test when the epilogue consumes too much shared memory
"""
def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5):
def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8):
D1 = accum + C1
D2 = D1 + C2
D3 = D2 + C3
D4 = D3 + C4
D = D4 + C5
return D, D1, D2, D3, D4
D5 = D4 + C5
D6 = D5 + C6
D7 = D6 + C7
D = D7 + C8
return D, D1, D2, D3, D4, D5, D6, D7
example_tensors = {
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
@ -108,10 +111,16 @@ class EVTErrorTests(unittest.TestCase):
"C3": self.fake_tensor(np.float16, (6, 512, 512)),
"C4": self.fake_tensor(np.float16, (6, 512, 512)),
"C5": self.fake_tensor(np.float16, (6, 512, 512)),
"C6": self.fake_tensor(np.float16, (6, 512, 512)),
"C7": self.fake_tensor(np.float16, (6, 512, 512)),
"C8": self.fake_tensor(np.float16, (6, 512, 512)),
"D1": self.fake_tensor(np.float16, (6, 512, 512)),
"D2": self.fake_tensor(np.float16, (6, 512, 512)),
"D3": self.fake_tensor(np.float16, (6, 512, 512)),
"D4": self.fake_tensor(np.float16, (6, 512, 512)),
"D5": self.fake_tensor(np.float16, (6, 512, 512)),
"D6": self.fake_tensor(np.float16, (6, 512, 512)),
"D7": self.fake_tensor(np.float16, (6, 512, 512)),
"D": self.fake_tensor(np.float16, (6, 512, 512))
}

View File

@ -85,7 +85,7 @@ class GemmEquivalence:
Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types
and layouts for constructing the Gemm interface
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
# Test when specifying all parameters
@ -126,7 +126,7 @@ class GemmEquivalence:
"""
Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend
"""
if not datatypes.numpy_available:
if not datatypes.is_numpy_available():
return
import numpy as np