@ -41,6 +41,7 @@ import cutlass
|
||||
import cutlass_bindings
|
||||
import cutlass.utils.datatypes as datatypes
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from utils import ExpectException
|
||||
|
||||
|
||||
class GemmEquivalence:
|
||||
@ -220,38 +221,6 @@ class GemmEquivalenceTest(unittest.TestCase):
|
||||
gemm_eq.test_all()
|
||||
|
||||
|
||||
class ExpectException:
|
||||
"""
|
||||
Utility class to assert that an exception was raised when expected
|
||||
|
||||
Example:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
with ExceptionExpected(True, 'Division by zero'):
|
||||
x = 1.0 / 0.0
|
||||
|
||||
:param exception_expected: whether an exception is expected to be raised
|
||||
:type exception_expected: bool
|
||||
:param message: message to print if an exception is raised when not expected or vice versa
|
||||
:type message: str
|
||||
"""
|
||||
def __init__(self, exception_expected: bool, message: str = ''):
|
||||
self.exception_expected = exception_expected
|
||||
self.message = message
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, traceback):
|
||||
exception_raised = exc_type is not None
|
||||
assert self.exception_expected == exception_raised, self.message
|
||||
|
||||
# Suppress the exception
|
||||
return True
|
||||
|
||||
|
||||
class GemmErrorTests(unittest.TestCase):
|
||||
"""
|
||||
Tests various error scenarios that arise with the high-level Gemm interface
|
||||
@ -316,9 +285,22 @@ class GemmErrorTests(unittest.TestCase):
|
||||
td.stages = 0
|
||||
plan.construct(td)
|
||||
|
||||
with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'):
|
||||
td.stages = 3
|
||||
plan.construct(td)
|
||||
if cc < 90:
|
||||
with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'):
|
||||
td.stages = 3
|
||||
plan.construct(td)
|
||||
else:
|
||||
original_kschedule = td.kernel_schedule
|
||||
original_eschedule = td.epilogue_schedule
|
||||
with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'):
|
||||
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
|
||||
td.epilogue_schedule = cutlass.EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
td.stages = 3
|
||||
plan.construct(td)
|
||||
|
||||
# Reset schedules
|
||||
td.kernel_schedule = original_kschedule
|
||||
td.epilogue_schedule = original_eschedule
|
||||
|
||||
with ExpectException(True, f'Requested too many stages'):
|
||||
td.stages = 100
|
||||
@ -335,9 +317,25 @@ class GemmErrorTests(unittest.TestCase):
|
||||
# Reset cluster shape
|
||||
td.cluster_shape = cluster_shape
|
||||
|
||||
kernel_schedule = td.kernel_schedule
|
||||
with ExpectException(cc < 90, f'Requested a persistent kernel on SM{cc}'):
|
||||
with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'):
|
||||
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
|
||||
td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized
|
||||
plan.construct(td)
|
||||
|
||||
with ExpectException(True, f'Requested a non-auto kernel schedule with an auto epilogue schedule'):
|
||||
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
|
||||
td.epilogue_schedule = cutlass.EpilogueScheduleType.ScheduleAuto
|
||||
plan.construct(td)
|
||||
|
||||
with ExpectException(True, f'Requested an auto kernel schedule with a non-auto epilogue schedule'):
|
||||
td.kernel_schedule = cutlass.KernelScheduleType.ScheduleAuto
|
||||
td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized
|
||||
plan.construct(td)
|
||||
|
||||
with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'):
|
||||
td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative
|
||||
td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
td.tile_scheduler = cutlass.TileSchedulerType.StreamK
|
||||
plan.construct(td)
|
||||
|
||||
# Ensure that all returned tile descriptions are unique
|
||||
|
||||
Reference in New Issue
Block a user