CUTLASS 3.4.0 (#1286)

* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
Pradeep Ramani
2023-12-29 12:21:31 -08:00
committed by GitHub
parent b7508e3379
commit 8236f30675
211 changed files with 11409 additions and 2763 deletions

View File

@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
cutlass.set_log_level(logging.WARNING)
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
class TestEVTComputeSM90(EVTTestCaseBase):
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
class TestEVTCompute(EVTTestCaseBase):
def test_arith(self):
"""

View File

@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
cutlass.set_log_level(logging.WARNING)
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
class TestEVTLayoutSM90(EVTTestCaseBase):
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
class TestEVTLayout(EVTTestCaseBase):
def test_permute_1(self):
"""
@ -74,7 +74,7 @@ class TestEVTLayoutSM90(EVTTestCaseBase):
result_keys = ["D", "F"]
launcher.verify((m, n, k), input_keys, result_keys, l)
@unittest.skipIf(device_cc() == 80, "This unittest is for cc = Sm90 only")
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
def test_permute_2(self):
"""
Returning a tensor with shape [m, n]
@ -99,7 +99,7 @@ class TestEVTLayoutSM90(EVTTestCaseBase):
result_keys = ["D", "F"]
launcher.verify((m, n, k), input_keys, result_keys, l)
@unittest.skipIf(device_cc() == 80, "This unittest is for cc = Sm90 only")
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
def test_permute_3(self):
"""
Returning a tensor with shape [m, n]

View File

@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
cutlass.set_log_level(logging.WARNING)
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
class TestEVTLoadSM90(EVTTestCaseBase):
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
class TestEVTLoad(EVTTestCaseBase):
def test_tensor_load(self):
"""

View File

@ -47,8 +47,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
cutlass.set_log_level(logging.WARNING)
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
class TestEVTMixedSM90(EVTTestCaseBase):
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
class TestEVTMixed(EVTTestCaseBase):
def test_mixed_dag(self):
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
F = alpha * accum + (beta * C + aux)
@ -84,7 +84,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
result_keys = ["D", "F", "F_row_max", "E_col_max"]
launcher.verify((m, n, k), input_keys, result_keys, l)
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
def test_mixed_dag_float(self):
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
F = alpha * accum + (beta * C + aux)
@ -114,7 +114,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
result_keys = ["D", "F", "F_row_max", "E_col_max"]
launcher.verify((m, n, k), input_keys, result_keys, l)
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
def test_mixed_dag_stage2(self):
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
F = alpha * accum + (beta * C + aux)
@ -144,7 +144,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
result_keys = ["D", "F", "F_row_max", "E_col_max"]
launcher.verify((m, n, k), input_keys, result_keys, l)
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
def test_mixed_dag_partition_k(self):
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
F = alpha * accum + (beta * C + aux)
@ -179,7 +179,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
result_keys = ["D", "F", "F_row_max", "E_col_max"]
launcher.verify((m, n, k), input_keys, result_keys, l)
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
def test_mixed_dag_stream_k(self):
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
F = alpha * accum + (beta * C + aux)

View File

@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
cutlass.set_log_level(logging.WARNING)
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
class TestEVTStoreSM90(EVTTestCaseBase):
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
class TestEVTStore(EVTTestCaseBase):
def test_aux_store(self):
"""

View File

@ -46,8 +46,10 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 80
dtype = cutlass.DataType.f16
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm80(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -56,13 +58,14 @@ class GemmF16Sm80(unittest.TestCase):
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm80StreamK(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
"""
pass
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1])
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)

View File

@ -46,8 +46,10 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 90
dtype = cutlass.DataType.f16
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm90(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -55,7 +57,7 @@ class GemmF16Sm90(unittest.TestCase):
pass
add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=cutlass.DataType.f16,
add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype,
warp_count=None, compilation_modes=['nvcc'])
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)

View File

@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 80
dtype = cutlass.DataType.f32
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF32Sm80(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -56,6 +59,7 @@ class GemmF32Sm80(unittest.TestCase):
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF32Sm80StreamK(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -63,37 +67,37 @@ class GemmF32Sm80StreamK(unittest.TestCase):
pass
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f32, cc=cc, cluster_shape=[1, 1, 1])
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
if __name__ == '__main__':

View File

@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 80
dtype = cutlass.DataType.f64
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF64Sm80(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -56,6 +59,7 @@ class GemmF64Sm80(unittest.TestCase):
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF64Sm80StreamK(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -63,36 +67,36 @@ class GemmF64Sm80StreamK(unittest.TestCase):
pass
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f64, cc=cc, cluster_shape=[1, 1, 1])
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
if __name__ == '__main__':

View File

@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 90
dtype = cutlass.DataType.f64
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF64Sm90(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -56,8 +59,7 @@ class GemmF64Sm90(unittest.TestCase):
add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1],
element=cutlass.DataType.f64, element_output=cutlass.DataType.f64,
element_accumulator=cutlass.DataType.f64, compilation_modes=['nvcc'])
element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc'])
add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3)
add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3)

View File

@ -0,0 +1,112 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Low-level functionality tests for GEMM with S8 operands on SM90
"""
from functools import partial
import logging
import unittest
import cutlass
from cutlass.backend.utils.device import device_cc
from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 90
dtype = cutlass.DataType.e4m3
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF8E4M3Sm90(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
"""
pass
add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc'])
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
# Test with 1x1x1 clusters
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
# Tests with different cluster shapes
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None)
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None)
# Tests with warp-specialized ping-pong schedule
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None,
kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong,
epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized)
# Tests for SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.e4m3,
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2)
#
# Add a test for E5M2
#
dtype = cutlass.DataType.e5m2
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF8E5M2Sm90(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
"""
pass
add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc'])
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
# Tests with 1x1x1 clusters
add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype,
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3)
if __name__ == '__main__':
unittest.main()

View File

@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 80
dtype =cutlass.DataType.f16
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmMixedSm80(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -55,7 +58,7 @@ class GemmMixedSm80(unittest.TestCase):
pass
add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1],
add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1],
opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64],
warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass.DataType.f32)

View File

@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 80
dtype = cutlass.DataType.s8
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmS8Sm80(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -56,6 +59,7 @@ class GemmS8Sm80(unittest.TestCase):
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmS8Sm80StreamK(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -63,7 +67,7 @@ class GemmS8Sm80StreamK(unittest.TestCase):
pass
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.s8, cc=cc, cluster_shape=[1, 1, 1])
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)

View File

@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
cutlass.set_log_level(logging.WARNING)
cc = 90
dtype = cutlass.DataType.s8
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmS8Sm90(unittest.TestCase):
"""
Wrapper class to which tests will be added dynamically in __main__
@ -55,7 +58,7 @@ class GemmS8Sm90(unittest.TestCase):
pass
add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=cutlass.DataType.s8, compilation_modes=['nvcc'])
add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc'])
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)

View File

@ -128,13 +128,22 @@ class GemmUniversalLauncher:
def uniform_init(self, shape, dtype, layout):
size = prod(shape)
if dtype.is_floating_point:
data = torch.ceil(torch.empty(size=(size,), dtype=dtype, device="cuda").uniform_(self.rand_min - 0.5, self.rand_max - 0.5))
# Initialize data in FP32 and call convert to the data type we desire.
# This is a workaround for the following error that occurs when attempting to
# call uniform_ on a tensor with torch.float8_e4m3fn data:
# RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn'
data = torch.ceil(
torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_(
self.rand_min - 0.5, self.rand_max - 0.5)
).to(dtype)
else:
# PyTorch does not currently support integer-typed matrix multiplications on GPU.
# Fall back to CPU for integer type references.
data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1)
if dtype == torch.float64 or dtype == torch.float32:
is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1)
if dtype == torch.float64 or dtype == torch.float32 or is_fp8:
data = data.to("cpu")
data_ref = data.reshape(shape)
@ -145,6 +154,12 @@ class GemmUniversalLauncher:
data_cutlass = data_ref.transpose(-1, -2).contiguous()
data_cutlass = data_cutlass.to("cuda")
# As of this writing, few operations in PyTorch are supported with FP8 data.
# Thus, we perform computation in FP32 for FP8 reference checks.
if is_fp8:
data_ref = data_ref.to(torch.float32)
return data_cutlass, data_ref
def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):