CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
|
||||
class TestEVTComputeSM90(EVTTestCaseBase):
|
||||
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
||||
class TestEVTCompute(EVTTestCaseBase):
|
||||
|
||||
def test_arith(self):
|
||||
"""
|
||||
|
||||
@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
|
||||
class TestEVTLayoutSM90(EVTTestCaseBase):
|
||||
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
||||
class TestEVTLayout(EVTTestCaseBase):
|
||||
|
||||
def test_permute_1(self):
|
||||
"""
|
||||
@ -74,7 +74,7 @@ class TestEVTLayoutSM90(EVTTestCaseBase):
|
||||
result_keys = ["D", "F"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
@unittest.skipIf(device_cc() == 80, "This unittest is for cc = Sm90 only")
|
||||
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
|
||||
def test_permute_2(self):
|
||||
"""
|
||||
Returning a tensor with shape [m, n]
|
||||
@ -99,7 +99,7 @@ class TestEVTLayoutSM90(EVTTestCaseBase):
|
||||
result_keys = ["D", "F"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
@unittest.skipIf(device_cc() == 80, "This unittest is for cc = Sm90 only")
|
||||
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
|
||||
def test_permute_3(self):
|
||||
"""
|
||||
Returning a tensor with shape [m, n]
|
||||
|
||||
@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
|
||||
class TestEVTLoadSM90(EVTTestCaseBase):
|
||||
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
||||
class TestEVTLoad(EVTTestCaseBase):
|
||||
|
||||
def test_tensor_load(self):
|
||||
"""
|
||||
|
||||
@ -47,8 +47,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
|
||||
class TestEVTMixedSM90(EVTTestCaseBase):
|
||||
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
||||
class TestEVTMixed(EVTTestCaseBase):
|
||||
def test_mixed_dag(self):
|
||||
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
||||
F = alpha * accum + (beta * C + aux)
|
||||
@ -84,7 +84,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
|
||||
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
|
||||
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
||||
def test_mixed_dag_float(self):
|
||||
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
||||
F = alpha * accum + (beta * C + aux)
|
||||
@ -114,7 +114,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
|
||||
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
|
||||
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
||||
def test_mixed_dag_stage2(self):
|
||||
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
||||
F = alpha * accum + (beta * C + aux)
|
||||
@ -144,7 +144,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
|
||||
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
|
||||
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
||||
def test_mixed_dag_partition_k(self):
|
||||
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
||||
F = alpha * accum + (beta * C + aux)
|
||||
@ -179,7 +179,7 @@ class TestEVTMixedSM90(EVTTestCaseBase):
|
||||
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
@unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only")
|
||||
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
||||
def test_mixed_dag_stream_k(self):
|
||||
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
||||
F = alpha * accum + (beta * C + aux)
|
||||
|
||||
@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
|
||||
class TestEVTStoreSM90(EVTTestCaseBase):
|
||||
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
||||
class TestEVTStore(EVTTestCaseBase):
|
||||
|
||||
def test_aux_store(self):
|
||||
"""
|
||||
|
||||
@ -46,8 +46,10 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 80
|
||||
dtype = cutlass.DataType.f16
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF16Sm80(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -56,13 +58,14 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF16Sm80StreamK(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
"""
|
||||
pass
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1])
|
||||
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
|
||||
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
@ -46,8 +46,10 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 90
|
||||
dtype = cutlass.DataType.f16
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF16Sm90(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -55,7 +57,7 @@ class GemmF16Sm90(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=cutlass.DataType.f16,
|
||||
add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype,
|
||||
warp_count=None, compilation_modes=['nvcc'])
|
||||
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 80
|
||||
dtype = cutlass.DataType.f32
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF32Sm80(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -56,6 +59,7 @@ class GemmF32Sm80(unittest.TestCase):
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF32Sm80StreamK(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -63,37 +67,37 @@ class GemmF32Sm80StreamK(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f32, cc=cc, cluster_shape=[1, 1, 1])
|
||||
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
|
||||
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
|
||||
# Tests using SIMT
|
||||
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
|
||||
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
|
||||
# Stream K tests
|
||||
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
|
||||
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 80
|
||||
dtype = cutlass.DataType.f64
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF64Sm80(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -56,6 +59,7 @@ class GemmF64Sm80(unittest.TestCase):
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF64Sm80StreamK(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -63,36 +67,36 @@ class GemmF64Sm80StreamK(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f64, cc=cc, cluster_shape=[1, 1, 1])
|
||||
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
|
||||
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
|
||||
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
|
||||
|
||||
# Tests using SIMT
|
||||
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
|
||||
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
||||
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
||||
|
||||
# Stream K tests
|
||||
add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK)
|
||||
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
||||
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
||||
element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 90
|
||||
dtype = cutlass.DataType.f64
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF64Sm90(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -56,8 +59,7 @@ class GemmF64Sm90(unittest.TestCase):
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1],
|
||||
element=cutlass.DataType.f64, element_output=cutlass.DataType.f64,
|
||||
element_accumulator=cutlass.DataType.f64, compilation_modes=['nvcc'])
|
||||
element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc'])
|
||||
|
||||
add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3)
|
||||
add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3)
|
||||
|
||||
112
test/python/cutlass/gemm/gemm_f8_sm90.py
Normal file
112
test/python/cutlass/gemm/gemm_f8_sm90.py
Normal file
@ -0,0 +1,112 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Low-level functionality tests for GEMM with S8 operands on SM90
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
|
||||
from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 90
|
||||
dtype = cutlass.DataType.e4m3
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF8E4M3Sm90(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc'])
|
||||
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
# Test with 1x1x1 clusters
|
||||
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
|
||||
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
|
||||
|
||||
# Tests with different cluster shapes
|
||||
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
|
||||
element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None)
|
||||
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
|
||||
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None)
|
||||
|
||||
# Tests with warp-specialized ping-pong schedule
|
||||
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3,
|
||||
element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None,
|
||||
kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized)
|
||||
|
||||
# Tests for SIMT
|
||||
add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt)
|
||||
add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.e4m3,
|
||||
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2)
|
||||
|
||||
|
||||
#
|
||||
# Add a test for E5M2
|
||||
#
|
||||
dtype = cutlass.DataType.e5m2
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmF8E5M2Sm90(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc'])
|
||||
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
# Tests with 1x1x1 clusters
|
||||
add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype,
|
||||
element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 80
|
||||
dtype =cutlass.DataType.f16
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmMixedSm80(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -55,7 +58,7 @@ class GemmMixedSm80(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1],
|
||||
add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1],
|
||||
opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64],
|
||||
warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass.DataType.f32)
|
||||
|
||||
|
||||
@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 80
|
||||
dtype = cutlass.DataType.s8
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmS8Sm80(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -56,6 +59,7 @@ class GemmS8Sm80(unittest.TestCase):
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmS8Sm80StreamK(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -63,7 +67,7 @@ class GemmS8Sm80StreamK(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.s8, cc=cc, cluster_shape=[1, 1, 1])
|
||||
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
|
||||
|
||||
# Tests using TensorOp
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
cc = 90
|
||||
dtype = cutlass.DataType.s8
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
||||
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
||||
class GemmS8Sm90(unittest.TestCase):
|
||||
"""
|
||||
Wrapper class to which tests will be added dynamically in __main__
|
||||
@ -55,7 +58,7 @@ class GemmS8Sm90(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=cutlass.DataType.s8, compilation_modes=['nvcc'])
|
||||
add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc'])
|
||||
|
||||
add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp)
|
||||
|
||||
|
||||
@ -128,13 +128,22 @@ class GemmUniversalLauncher:
|
||||
def uniform_init(self, shape, dtype, layout):
|
||||
size = prod(shape)
|
||||
if dtype.is_floating_point:
|
||||
data = torch.ceil(torch.empty(size=(size,), dtype=dtype, device="cuda").uniform_(self.rand_min - 0.5, self.rand_max - 0.5))
|
||||
# Initialize data in FP32 and call convert to the data type we desire.
|
||||
# This is a workaround for the following error that occurs when attempting to
|
||||
# call uniform_ on a tensor with torch.float8_e4m3fn data:
|
||||
# RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn'
|
||||
data = torch.ceil(
|
||||
torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_(
|
||||
self.rand_min - 0.5, self.rand_max - 0.5)
|
||||
).to(dtype)
|
||||
else:
|
||||
# PyTorch does not currently support integer-typed matrix multiplications on GPU.
|
||||
# Fall back to CPU for integer type references.
|
||||
data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1)
|
||||
|
||||
if dtype == torch.float64 or dtype == torch.float32:
|
||||
is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1)
|
||||
|
||||
if dtype == torch.float64 or dtype == torch.float32 or is_fp8:
|
||||
data = data.to("cpu")
|
||||
|
||||
data_ref = data.reshape(shape)
|
||||
@ -145,6 +154,12 @@ class GemmUniversalLauncher:
|
||||
data_cutlass = data_ref.transpose(-1, -2).contiguous()
|
||||
|
||||
data_cutlass = data_cutlass.to("cuda")
|
||||
|
||||
# As of this writing, few operations in PyTorch are supported with FP8 data.
|
||||
# Thus, we perform computation in FP32 for FP8 reference checks.
|
||||
if is_fp8:
|
||||
data_ref = data_ref.to(torch.float32)
|
||||
|
||||
return data_cutlass, data_ref
|
||||
|
||||
def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
|
||||
|
||||
Reference in New Issue
Block a user