432 lines
18 KiB
Python
432 lines
18 KiB
Python
#################################################################################################
|
|
#
|
|
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
#################################################################################################
|
|
|
|
"""
|
|
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
|
"""
|
|
|
|
from bisect import bisect_left
|
|
|
|
from cutlass_library import (
|
|
DataType,
|
|
DataTypeSize,
|
|
MathOperation,
|
|
OperationKind,
|
|
SharedMemPerCC
|
|
)
|
|
|
|
import cutlass_cppgen
|
|
from cutlass_cppgen import get_option_registry
|
|
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
|
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
|
from cutlass_cppgen.backend.utils.device import device_cc
|
|
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
|
|
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
|
|
from cutlass_cppgen.swizzle import get_swizzling_functors
|
|
from cutlass_cppgen.utils import datatypes, check
|
|
|
|
|
|
class OperationBase:
    """
    Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)

    Tracks the device compute capability (``cc``), the compute capability of the kernels
    currently being generated (``current_cc``), the kernel options registered for that
    kernel CC (``options``), the requested math operation, and the epilogue configuration.

    Several attributes and methods used here are not initialized by ``__init__`` and are
    expected to be provided by subclasses (e.g., ``possible_op_classes``,
    ``possible_operations``, ``epilogue_functor``, the ``_element_*`` and ``_layout_*``
    attributes, ``_reset_operations`` and ``_verify_type_and_layout``).
    """

    # Kernel compute capabilities served by CUTLASS 3.x kernels in this interface. On these
    # CCs, requesting a specific math operation or a non-identity activation falls back to
    # SM80-tagged (CUTLASS 2.x) kernels unless ``kernel_cc`` was pinned by the user.
    _CUTLASS_3X_CCS = (90, 100, 101, 103)

    def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
        """
        :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
        :type cc: int
        :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
        :type kernel_cc: int
        :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
        :type operation_kind: cutlass_library.OperationKind

        :raises Exception: if no generator CC can serve ``cc``, or if no kernel options are
            registered for the selected kernel compute capability
        """
        self.operation_kind = operation_kind
        # Fall back to the compute capability of the current device when none is given.
        self.cc = cc if cc is not None else device_cc()
        # When the user pins the kernel CC, automatic switching between 3.x and 2.x kernels
        # is disabled; unsupported requests raise instead of silently falling back.
        self.specified_kernel_cc = kernel_cc is not None
        self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
        self.tile_description = None
        self._math_operation = None

        self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)

        if self.options is None:
            raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")

        # Default activation function: identity
        self._activation = identity

    def _find_closest_cc(self, cc: int) -> int:
        """
        Returns the closest CC in _generator_ccs less than or equal to `cc`

        Relies on ``bisect_left``, so ``_generator_ccs`` is assumed to be sorted in
        ascending order.

        :param cc: compute capability to query
        :type cc: int

        :returns: closest CC in _generator_ccs less than or equal to `cc`
        :rtype: int

        :raises Exception: if every CC in ``_generator_ccs`` is greater than ``cc``
        """
        if cc in _generator_ccs:
            return cc

        # ``cc`` is not present, so bisect_left yields the index of the first entry greater
        # than ``cc``; the entry just before it is the closest lower CC.
        idx = bisect_left(_generator_ccs, cc)
        if idx == 0:
            raise Exception(f'No valid CC to fall back to for {cc}')
        return _generator_ccs[idx-1]

    def activations(self) -> list:
        """
        Returns possible activation functions that can be used

        :return: list of activation functions that can be used
        :rtype: list
        """
        return get_activations()

    def swizzling_functors(self) -> list:
        """
        Returns possible swizzling functions that can be used

        :return: list of swizzling functions that can be used
        :rtype: list
        """
        return get_swizzling_functors()

    def _reset_options(self, cc: int):
        """
        Resets the kernel options based on cc

        Does nothing when ``cc`` already equals ``current_cc``.

        :param cc: compute capability to reset to
        :type cc: int

        :raises Exception: if ``cc`` is not one of the generator CCs
        """
        if cc != self.current_cc:
            if cc not in _generator_ccs:
                raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
            self.current_cc = cc
            self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)

    def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
        """
        Verifies the following properties:
            1) Either ``scalar`` or ``ref_scalar`` must be set (i.e., not ``None``)
            2) If ``scalar`` is not ``None`` and has a ``dtype`` attribute, its datatype must match
               the current version set by the plan (i.e., ``ref_dtype``)

        If either of these properties does not hold, an exception is raised. If these properties hold and
        ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.

        Scalars without a ``dtype`` attribute (e.g., plain Python numbers) are not type-checked.

        :param scalar: object representing a scalar passed in to verify, or ``None`` if no scalar was passed in
        :type scalar: numpy/cupy/torch scalar
        :param ref_scalar: object representing a scalar passed in on construction of this object, or ``None`` if no scalar was passed in
        :type ref_scalar: numpy/cupy/torch scalar
        :param ref_dtype: data type for the scalar that this object was initialized to
        :param name: identifier of the scalar to verify. Used in raising exceptions
        :type name: str

        :return: valid scalar to use
        :rtype: numpy/cupy/torch scalar
        """
        if scalar is None:
            if ref_scalar is None:
                raise Exception(f"Scalar {name} must be set.")
            return ref_scalar
        if hasattr(scalar, "dtype"):
            dtype = datatypes.library_type(scalar.dtype)
            if dtype != ref_dtype:
                raise Exception(
                    f"Scalar {name} with type {dtype} does not match expected type {ref_dtype}."
                )
        return scalar

    def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
        """
        Verifies the following properties:
            If ref_dtype is not void:
                1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
                2) If ``tensor`` is not ``None``, its datatype and layout must match the current versions
                   set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
            If ref_dtype is void:
                Neither ``tensor`` nor ``ref_tensor`` are set

        If either of these properties does not hold, an exception is raised. If these properties hold and
        ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned
        (``None`` for void operands).

        :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
        :type tensor: numpy/cupy/torch array/tensor object
        :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
        :type ref_tensor: numpy/cupy/torch array/tensor object
        :param ref_dtype: data type for the tensor that this object was initialized to
        :param ref_layout: layout for the tensor that this object was initialized to
        :param name: identifier of the tensor to verify. Used in raising exceptions
        :type name: str

        :return: valid tensor object to use
        :rtype: numpy/cupy/torch array/tensor object
        """
        if ref_dtype == DataType.void:
            if tensor is not None or ref_tensor is not None:
                raise Exception("Operands with element DataType.void must not be provided a tensor")
            return None

        if tensor is None:
            if ref_tensor is None:
                raise Exception(f"Tensor {name} must be set.")
            return ref_tensor

        # Type/layout checking is delegated to a subclass-provided method.
        self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
        return tensor

    @property
    def opclass(self) -> cutlass_cppgen.OpcodeClass:
        """
        Returns the opcode class currently in use

        :return: opcode class currently in use
        :rtype: cutlass_cppgen.OpcodeClass
        """
        return self.op_class

    @opclass.setter
    def opclass(self, oc: cutlass_cppgen.OpcodeClass):
        """
        Sets the opcode class, accepting either an enum member or its name as a string.

        Also refreshes ``possible_operations`` and, if an epilogue functor is set, rebuilds
        it with the elements-per-access appropriate for the new opcode class.

        :raises Exception: if ``oc`` is not in ``possible_op_classes``
        """
        if isinstance(oc, str):
            oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
        if oc in self.possible_op_classes:
            self.op_class = oc
        else:
            raise Exception(
                f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
                f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
                f'layout combination ({self._layout_a}, {self._layout_b}).')

        # Changing the op class also changes the possible operations available. Reset these.
        self.possible_operations = self.options.operations(
            self.op_class, self._element_a, self._element_b,
            self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)

        # Changing the op class changes the elements per access in the epilogue. Reset this.
        if self.epilogue_functor is not None:
            self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)

    @property
    def math_operation(self) -> cutlass_cppgen.MathOperation:
        """
        Returns the math operation currently in use

        :return: math operation currently in use
        :rtype: cutlass_cppgen.MathOperation
        """
        return self._math_operation

    @math_operation.setter
    def math_operation(self, mo: cutlass_cppgen.MathOperation):
        """
        Sets the math operation, accepting either an enum member or its name as a string.

        When the current kernel CC is a CUTLASS 3.x CC and ``kernel_cc`` was not pinned, this
        falls back to SM80-tagged kernels (which may change the opcode class).

        :raises Exception: if ``kernel_cc`` was pinned to a CUTLASS 3.x CC
        """
        if isinstance(mo, str):
            mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)

        if not self.specified_kernel_cc:
            if self.current_cc in self._CUTLASS_3X_CCS:
                # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
                # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
                cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
                self._reset_options(80)
                self._reset_operations(reset_epilogue=False)
        elif self.current_cc in self._CUTLASS_3X_CCS:
            raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
                            "To use 2.x kernels with a specific math operation, do not set the `kernel_cc` "
                            "parameter when constructing the plan.")

        self._math_operation = mo
        self._reset_operations()

    def _elements_per_access(self):
        """
        Returns the default epilogue vector length for the current configuration.

        SIMT kernels use 1. Otherwise, when C has a non-void element type, this is the number
        of C elements in 128 bits. When C is void, it is derived from the largest C alignment
        among the possible operations.
        """
        if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
            return 1
        elif self._element_c != DataType.void:
            return 128 // DataTypeSize[self._element_c]
        else:
            return 128 // max(self.possible_operations.alignments("C"))

    def _create_epilogue_functor_activation(self, activation):
        """
        Returns the epilogue functor with given activation function

        If ``kernel_cc`` was not pinned, this may switch between CUTLASS 3.x kernels and
        SM80-tagged (CUTLASS 2.x) fallback kernels, because the Python interface only supports
        the identity activation for 3.x kernels.

        :param activation: activation function to fuse into the epilogue
        :raises Exception: if a non-identity activation is requested for a pinned 3.x kernel CC,
            or if falling back to 2.x kernels while element C differs from element D
        """
        # Keep the current epilogue's vector length if one exists; otherwise use the default.
        if self.epilogue_functor is None:
            elements_per_access = self._elements_per_access()
        else:
            elements_per_access = self.epilogue_functor.epilogue_vector_length

        if not self.specified_kernel_cc:
            if self.current_cc in self._CUTLASS_3X_CCS and activation != identity:
                # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
                # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
                cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
                if self._element_c != self._element_d:
                    raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
                self._reset_options(80)
                self._reset_operations(reset_epilogue=False)
            elif (self.cc in self._CUTLASS_3X_CCS
                  and self.current_cc not in self._CUTLASS_3X_CCS
                  and activation == identity
                  and self._math_operation is None):
                # SM80 fallback kernels are currently used. Since an identity activation is requested,
                # we can switch back to using 3.x kernels. Resolve the target CC the same way
                # ``__init__`` does, so a device CC without its own generator entry maps to the
                # closest lower generator CC instead of being rejected by ``_reset_options``.
                self._reset_options(self._find_closest_cc(self.cc))
                self._reset_operations(reset_epilogue=False)
        else:
            if self.current_cc in self._CUTLASS_3X_CCS and activation != identity:
                raise Exception("Epilogues with elementwise fusion are not currently supported "
                                "in the Python interface for 3.x kernels. To use 2.x kernels "
                                "with fused elementwise epilogues, do not set the `kernel_cc` "
                                "parameter when constructing the plan.")

        return get_activation_epilogue(
            activation,
            self._element_d,
            elements_per_access,
            self._element_accumulator,
            self._element_accumulator,
        )

    def _reset_epilogue_functor_activation(self, activation):
        """
        Set the epilogue functor based on the provided activation function
        """
        self.epilogue_functor = self._create_epilogue_functor_activation(activation)

    def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
        """
        Reset the alignment of the current epilogue functor based on alignment C

        Visitor (EVT) epilogues are returned unchanged. Otherwise a new activation epilogue is
        built with the same activation function (identity if none) and the given alignment.

        :param alignment: epilogue vector length to use
        :param epilogue_functor: current epilogue functor, or ``None``
        :return: epilogue functor with the requested alignment
        """
        if isinstance(epilogue_functor, EpilogueFunctorVisitor):
            return epilogue_functor

        if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
            # Identity epilogue does not have 'activation_functor'
            activation = identity
        else:
            activation = epilogue_functor.activation_functor

        epilogue_functor = get_activation_epilogue(
            activation,
            self._element_d,
            alignment,
            self._element_accumulator,
            self._element_accumulator,
        )
        return epilogue_functor

    @property
    def activation(self):
        """
        Returns the type of the current activation function used

        Epilogue functors without an ``activation_functor`` attribute are reported as identity.
        """
        if hasattr(self.epilogue_functor, "activation_functor"):
            return self.epilogue_functor.activation_functor
        else:
            return identity

    @activation.setter
    def activation(self, act):
        """
        Sets the type of the activation function to use
        Activation can come with a set of arguments

        :param act: type of activation function to use. String names are resolved as
            attributes of ``cutlass_cppgen.backend.epilogue``.
        :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
        """
        if isinstance(act, tuple):
            act_fn = act[0]
            if isinstance(act_fn, str):
                act_fn = getattr(cutlass_cppgen.backend.epilogue, act_fn)
            self._reset_epilogue_functor_activation(act_fn)
            self._activation_args = act[1]
            # Store the resolved function (not its name) so ``_activation`` holds the same
            # kind of value as in the non-tuple branch below.
            self._activation = act_fn
        else:
            if isinstance(act, str):
                act = getattr(cutlass_cppgen.backend.epilogue, act)
            self._reset_epilogue_functor_activation(act)
            self._activation = act

    @property
    def epilogue_visitor(self):
        """
        Return the epilogue functor
        """
        return self.epilogue_functor

    @epilogue_visitor.setter
    def epilogue_visitor(self, visitor):
        """
        Create the epilogue visitor

        On CUTLASS 3.x device CCs, this also prunes ``possible_operations`` to those whose
        epilogue schedule is supported (SM90 only) and which still fit at least two mainloop
        stages in shared memory alongside the visitor's shared memory usage.

        :raises RuntimeError: if no candidate operation remains after pruning
        """
        self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)

        # The epilogue_functor may consume too much shared memory
        # Reset the possible operations
        if self.cc not in self._CUTLASS_3X_CCS:
            # The shared memory is only a concern for sm90+ epilogue
            # In sm80, the epilogue and mainloop share the shared memory
            return

        datatype_comb = self.possible_operations.datatype_comb
        layout_comb = self.possible_operations.layout_comb
        new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
        for operation in self.possible_operations.all_operations:
            td = datatypes.td_from_profiler_op(operation)
            # Filter invalid epilogue schedules
            if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
                    cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
                    cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
                continue
            epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)

            # Verify the maximum number of mainloop stages
            mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
            # SharedMemPerCC is shifted by 10, i.e. treated as KiB and converted to bytes.
            smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
            mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
            if mainloop_stages < 2:
                # Mainloop stages must >= 2
                continue

            new_possible_operations.add(operation)
        if len(new_possible_operations.all_operations) == 0:
            raise RuntimeError(
                "The epilogue consumes too much shared memory. "
                "No valid tile description is found in the generator.")
        self.possible_operations = new_possible_operations

    def run_setup(self):
        """
        Steps that must be taken before calling `plan.run()`
        """
        # Initialize the memory pool, if not already done
        cutlass_cppgen.get_memory_pool()