CUTLASS 3.2 (#1024)

* CUTLASS 3.2
This commit is contained in:
ANIKET SHIVAM
2023-08-07 14:50:32 -10:00
committed by GitHub
parent a0d787b746
commit 4575443d44
392 changed files with 47559 additions and 7940 deletions

View File

@ -41,6 +41,7 @@ import cutlass
import cutlass_bindings
import cutlass.utils.datatypes as datatypes
from cutlass.backend.utils.device import device_cc
from utils import ExpectException
class GemmEquivalence:
@ -220,38 +221,6 @@ class GemmEquivalenceTest(unittest.TestCase):
gemm_eq.test_all()
class ExpectException:
"""
Utility class to assert that an exception was raised when expected
Example:
.. highlight:: python
.. code-block:: python
with ExceptionExpected(True, 'Division by zero'):
x = 1.0 / 0.0
:param exception_expected: whether an exception is expected to be raised
:type exception_expected: bool
:param message: message to print if an exception is raised when not expected or vice versa
:type message: str
"""
def __init__(self, exception_expected: bool, message: str = ''):
self.exception_expected = exception_expected
self.message = message
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, traceback):
exception_raised = exc_type is not None
assert self.exception_expected == exception_raised, self.message
# Suppress the exception
return True
class GemmErrorTests(unittest.TestCase):
"""
Tests various error scenarios that arise with the high-level Gemm interface
@ -316,9 +285,22 @@ class GemmErrorTests(unittest.TestCase):
td.stages = 0
plan.construct(td)
with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'):
td.stages = 3
plan.construct(td)
if cc < 90:
with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'):
td.stages = 3
plan.construct(td)
else:
original_kschedule = td.kernel_schedule
original_eschedule = td.epilogue_schedule
with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'):
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
td.epilogue_schedule = cutlass.EpilogueScheduleType.NoSmemWarpSpecialized
td.stages = 3
plan.construct(td)
# Reset schedules
td.kernel_schedule = original_kschedule
td.epilogue_schedule = original_eschedule
with ExpectException(True, f'Requested too many stages'):
td.stages = 100
@ -335,9 +317,25 @@ class GemmErrorTests(unittest.TestCase):
# Reset cluster shape
td.cluster_shape = cluster_shape
kernel_schedule = td.kernel_schedule
with ExpectException(cc < 90, f'Requested a persistent kernel on SM{cc}'):
with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'):
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized
plan.construct(td)
with ExpectException(True, f'Requested a non-auto kernel schedule with an auto epilogue schedule'):
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
td.epilogue_schedule = cutlass.EpilogueScheduleType.ScheduleAuto
plan.construct(td)
with ExpectException(True, f'Requested an auto kernel schedule with a non-auto epilogue schedule'):
td.kernel_schedule = cutlass.KernelScheduleType.ScheduleAuto
td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized
plan.construct(td)
with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'):
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative
td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative
td.tile_scheduler = cutlass.TileSchedulerType.StreamK
plan.construct(td)
# Ensure that all returned tile descriptions are unique