# File: cutlass/python/cutlass_cppgen/op/op.py
# Timestamp: 2025-09-23 14:10:50 -07:00
# 432 lines, 18 KiB, Python

#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
"""
from bisect import bisect_left
from cutlass_library import (
DataType,
DataTypeSize,
MathOperation,
OperationKind,
SharedMemPerCC
)
import cutlass_cppgen
from cutlass_cppgen import get_option_registry
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
from cutlass_cppgen.backend.evt.passes.util import cc_map
from cutlass_cppgen.backend.utils.device import device_cc
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
from cutlass_cppgen.swizzle import get_swizzling_functors
from cutlass_cppgen.utils import datatypes, check
class OperationBase:
"""
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
"""
def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
    """
    Sets up the compute-capability state and kernel options shared by all operations

    :param cc: compute capability of device for which kernels should be compiled (e.g., 90 when
        running on H100). When ``None``, the value returned by ``device_cc()`` is used
    :type cc: int
    :param kernel_cc: compute capability of kernels to generate (e.g., 80 to use a CUTLASS 2.x-style
        Ampere kernel while running on SM90). When ``None``, the closest generator CC less than
        or equal to ``cc`` is used
    :type kernel_cc: int
    :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
    :type operation_kind: cutlass_library.OperationKind

    :raises Exception: if the option registry returns no options for the selected kernel CC
    """
    self.operation_kind = operation_kind

    # Device CC: explicit value wins, otherwise query it
    if cc is None:
        cc = device_cc()
    self.cc = cc

    # Kernel CC: remember whether the caller pinned it, since later fallbacks depend on that
    self.specified_kernel_cc = kernel_cc is not None
    if kernel_cc is not None:
        self.current_cc = kernel_cc
    else:
        self.current_cc = self._find_closest_cc(self.cc)

    self.tile_description = None
    self._math_operation = None

    registry = get_option_registry()
    self.options = registry.options_for_cc(self.current_cc, operation_kind)
    if self.options is None:
        raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")

    # Start out with the identity activation
    self._activation = identity
def _find_closest_cc(self, cc: int) -> int:
    """
    Maps ``cc`` onto an entry of ``_generator_ccs``

    :param cc: compute capability to query
    :type cc: int

    :returns: ``cc`` itself if it is in ``_generator_ccs``; otherwise the entry immediately
        preceding the position at which ``cc`` would be inserted
    :rtype: int

    :raises Exception: if no entry precedes the insertion position of ``cc``
    """
    if cc not in _generator_ccs:
        # bisect assumes _generator_ccs is sorted in ascending order; the slot just
        # before the insertion point then holds the closest lower CC
        insert_at = bisect_left(_generator_ccs, cc)
        if insert_at == 0:
            raise Exception(f'No valid CC to fall back to for {cc}')
        return _generator_ccs[insert_at - 1]

    return cc
def activations(self) -> list:
    """
    Lists the activation functions available to this operation

    :return: result of ``get_activations()``
    :rtype: list
    """
    available = get_activations()
    return available
def swizzling_functors(self) -> list:
    """
    Lists the swizzling functors available to this operation

    :return: result of ``get_swizzling_functors()``
    :rtype: list
    """
    available = get_swizzling_functors()
    return available
def _reset_options(self, cc: int):
    """
    Switches the plan's kernel CC to ``cc`` and reloads the kernel options for it

    Nothing happens when ``cc`` already equals the current kernel CC.

    :param cc: compute capability to reset to
    :type cc: int

    :raises Exception: if ``cc`` differs from the current kernel CC and is not in ``_generator_ccs``
    """
    if cc == self.current_cc:
        return

    if cc not in _generator_ccs:
        raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')

    self.current_cc = cc
    registry = get_option_registry()
    self.options = registry.options_for_cc(self.current_cc, self.operation_kind)
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
    """
    Verifies the following properties:
        1) Either ``scalar`` or ``ref_scalar`` must be set (i.e., not ``None``)
        2) If ``scalar`` is not ``None`` and exposes a ``dtype`` attribute, its datatype must match
           the current version set by the plan (i.e., ``ref_dtype``)

    If either of these properties does not hold, an exception is raised. If these properties hold and
    ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.

    :param scalar: scalar passed in to verify, or ``None`` if no scalar was passed in
    :type scalar: numpy/cupy/torch scalar
    :param ref_scalar: scalar passed in on construction of this object, or ``None`` if no scalar was passed in
    :type ref_scalar: numpy/cupy/torch scalar
    :param ref_dtype: data type for the scalar that this object was initialized to
    :param name: identifier of the scalar to verify. Used in raising exceptions
    :type name: str

    :return: valid scalar to use
    :rtype: numpy/cupy/torch scalar

    :raises Exception: if neither scalar is set, or if the datatype of ``scalar`` does not match ``ref_dtype``
    """
    if scalar is None:
        if ref_scalar is None:
            raise Exception(f"Scalar {name} must be set.")
        return ref_scalar

    # Objects without a ``dtype`` attribute cannot be type-checked and are accepted as-is
    if hasattr(scalar, "dtype"):
        dtype = datatypes.library_type(scalar.dtype)
        if dtype != ref_dtype:
            raise Exception(
                f"Scalar {name} with type {dtype} does not match expected type {ref_dtype}."
            )

    return scalar
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
    """
    Selects the tensor to use for an operand and validates it against the plan

    If ``ref_dtype`` is void:
        Neither ``tensor`` nor ``ref_tensor`` may be set; ``None`` is returned.

    If ``ref_dtype`` is not void:
        1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
        2) If ``tensor`` is not ``None``, it is checked against ``ref_dtype`` and ``ref_layout``
           via ``self._verify_type_and_layout`` and then returned
        3) Otherwise, ``ref_tensor`` is returned

    :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
    :type tensor: numpy/cupy/torch array/tensor object
    :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
    :type ref_tensor: numpy/cupy/torch array/tensor object
    :param ref_dtype: data type for the tensor that this object was initialized to
    :param ref_layout: layout for the tensor that this object was initialized to
    :param name: identifier of the tensor to verify. Used in raising exceptions
    :type name: str

    :return: valid tensor object to use
    :rtype: numpy/cupy/torch array/tensor object

    :raises Exception: if a tensor is given for a void operand, or no tensor is available for a non-void one
    """
    if ref_dtype == DataType.void:
        # A void operand carries no data, so any tensor at all is an error
        if not (tensor is None and ref_tensor is None):
            raise Exception("Operands with element DataType.void must not be provided a tensor")
        return None

    if tensor is not None:
        self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
        return tensor

    # No tensor given for this call: fall back to the one given at construction
    if ref_tensor is None:
        raise Exception(f"Tensor {name} must be set.")
    return ref_tensor
@property
def opclass(self) -> cutlass_cppgen.OpcodeClass:
    """
    Opcode class the plan is currently using

    :return: value of ``self.op_class``
    :rtype: cutlass_cppgen.OpcodeClass
    """
    current = self.op_class
    return current
@opclass.setter
def opclass(self, oc: cutlass_cppgen.OpcodeClass):
    """
    Selects a new opcode class and refreshes the state that depends on it

    :param oc: opcode class to use; a string is first resolved through ``datatypes.getattr_enum``
    :type oc: cutlass_cppgen.OpcodeClass or str

    :raises Exception: if ``oc`` is not in ``self.possible_op_classes``
    """
    if isinstance(oc, str):
        oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)

    if oc not in self.possible_op_classes:
        raise Exception(
            f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
            f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
            f'layout combination ({self._layout_a}, {self._layout_b}).')
    self.op_class = oc

    # The set of available operations depends on the op class, so look it up again
    self.possible_operations = self.options.operations(
        self.op_class, self._element_a, self._element_b,
        self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)

    # The epilogue's elements per access depend on the op class as well
    if self.epilogue_functor is not None:
        elements_per_access = self._elements_per_access()
        self.epilogue_functor = self._reset_epilogue_functor_alignment(
            elements_per_access, self.epilogue_functor)
@property
def math_operation(self) -> cutlass_cppgen.MathOperation:
    """
    Math operation the plan is currently using

    :return: value of ``self._math_operation`` (``None`` until one has been set)
    :rtype: cutlass_cppgen.MathOperation
    """
    current = self._math_operation
    return current
@math_operation.setter
def math_operation(self, mo: cutlass_cppgen.MathOperation):
    """
    Sets the math operation to use and resets the candidate operations accordingly

    If ``kernel_cc`` was not specified at construction and the plan currently targets one of the
    CCs 90, 100, 101 or 103, the plan is switched to SM80-tagged kernels first.

    :param mo: math operation to use; a string is first resolved through ``datatypes.getattr_enum``
    :type mo: cutlass_cppgen.MathOperation or str

    :raises Exception: if ``kernel_cc`` was specified at construction and the current kernel CC is
        one of 90, 100, 101 or 103
    """
    if isinstance(mo, str):
        mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)

    if self.current_cc in [90, 100, 101, 103]:
        if self.specified_kernel_cc:
            # The user pinned the kernel CC, so silently falling back is not an option
            raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
                            "To use 2.x kernels with a specific math operation, do not set the `kernel_cc` "
                            "parameter when constructing the plan.")

        # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
        # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
        cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
        self._reset_options(80)
        self._reset_operations(reset_epilogue=False)

    self._math_operation = mo
    self._reset_operations()
def _elements_per_access(self):
    """
    Computes the elements-per-access value used when building the epilogue functor

    :return: 1 for the SIMT opcode class; otherwise 128 divided by the ``DataTypeSize`` entry of
        element C, or, when element C is void, 128 divided by the largest value in
        ``self.possible_operations.alignments("C")``
    :rtype: int
    """
    # NOTE(review): 128 presumably is the access width in bits -- confirm against DataTypeSize units
    if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
        return 1

    if self._element_c != DataType.void:
        return 128 // DataTypeSize[self._element_c]

    largest_alignment_c = max(self.possible_operations.alignments("C"))
    return 128 // largest_alignment_c
def _create_epilogue_functor_activation(self, activation):
    """
    Builds an epilogue functor that applies ``activation``

    As a side effect, when ``kernel_cc`` was not specified at construction, the plan may switch
    between 3.x kernels and SM80-tagged kernels depending on whether ``activation`` is the identity.

    :param activation: activation function to use in the epilogue

    :return: functor created by ``get_activation_epilogue``

    :raises Exception: if a non-identity activation is requested for a 3.x kernel while ``kernel_cc``
        was specified, or if a fallback to 2.x kernels is needed but element C differs from element D
    """
    # Keep the vector length of an existing functor; derive it from the plan otherwise
    existing_functor = self.epilogue_functor
    elements_per_access = (
        self._elements_per_access()
        if existing_functor is None
        else existing_functor.epilogue_vector_length
    )

    if self.specified_kernel_cc:
        if self.current_cc in [90, 100, 101, 103] and activation != identity:
            raise Exception("Epilogues with elementwise fusion are not currently supported "
                            "in the Python interface for 3.x kernels. To use 2.x kernels "
                            "with fused elementwise epilogues, do not set the `kernel_cc` "
                            "parameter when constructing the plan.")
    elif self.current_cc in [90, 100, 101, 103] and activation != identity:
        # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
        # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
        cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
        if self._element_c != self._element_d:
            raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
        self._reset_options(80)
        self._reset_operations(reset_epilogue=False)
    elif (self.cc in [90, 100, 101, 103]
          and self.current_cc not in [90, 100, 101, 103]
          and activation == identity
          and self._math_operation is None):
        # SM80 fallback kernels are currently used. Since an identity activation is requested,
        # we can switch back to using the kernels of the device CC.
        self._reset_options(self.cc)
        self._reset_operations(reset_epilogue=False)

    return get_activation_epilogue(
        activation,
        self._element_d,
        elements_per_access,
        self._element_accumulator,
        self._element_accumulator,
    )
def _reset_epilogue_functor_activation(self, activation):
    """
    Replaces the plan's epilogue functor with one built for ``activation``

    :param activation: activation function passed to ``_create_epilogue_functor_activation``
    """
    new_functor = self._create_epilogue_functor_activation(activation)
    self.epilogue_functor = new_functor
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
    """
    Rebuilds an epilogue functor with a new alignment while keeping its activation

    :param alignment: value passed to ``get_activation_epilogue`` as the elements per access
    :param epilogue_functor: functor to rebuild, or ``None``

    :return: ``epilogue_functor`` itself if it is an ``EpilogueFunctorVisitor``; otherwise a functor
        newly created by ``get_activation_epilogue``
    """
    # Visitor-based epilogues are returned untouched
    if isinstance(epilogue_functor, EpilogueFunctorVisitor):
        return epilogue_functor

    # Identity epilogue does not have 'activation_functor'; treat a missing functor the same way
    if epilogue_functor is not None and hasattr(epilogue_functor, 'activation_functor'):
        activation = epilogue_functor.activation_functor
    else:
        activation = identity

    return get_activation_epilogue(
        activation,
        self._element_d,
        alignment,
        self._element_accumulator,
        self._element_accumulator,
    )
@property
def activation(self):
    """
    Activation function of the current epilogue functor

    :return: the functor's ``activation_functor`` attribute if it has one, otherwise ``identity``
    """
    functor = self.epilogue_functor
    if not hasattr(functor, "activation_functor"):
        return identity
    return functor.activation_functor
@activation.setter
def activation(self, act):
    """
    Sets the type of the activation function to use

    Activation can come with a set of arguments, in which case ``act`` is a tuple whose first
    element is the activation and whose second element holds its arguments. An activation
    given as a string is looked up by name in ``cutlass_cppgen.backend.epilogue``.

    :param act: type of activation function to use
    :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
    """
    if isinstance(act, tuple):
        act_fn = act[0]
        if isinstance(act_fn, str):
            act_fn = getattr(cutlass_cppgen.backend.epilogue, act_fn)
        self._reset_epilogue_functor_activation(act_fn)
        self._activation_args = act[1]
        # Store the resolved functor rather than the raw name, so that ``_activation`` always
        # holds a functor regardless of how the activation was specified
        self._activation = act_fn
    else:
        if isinstance(act, str):
            act = getattr(cutlass_cppgen.backend.epilogue, act)
        self._reset_epilogue_functor_activation(act)
        self._activation = act
@property
def epilogue_visitor(self):
    """
    Epilogue functor currently attached to the plan

    :return: value of ``self.epilogue_functor``
    """
    functor = self.epilogue_functor
    return functor
@epilogue_visitor.setter
def epilogue_visitor(self, visitor):
    """
    Wraps ``visitor`` in an ``EpilogueFunctorVisitor`` and, for device CCs 90, 100, 101 and 103,
    keeps only the candidate operations that leave room for at least two mainloop stages

    :param visitor: epilogue visitor passed through to ``EpilogueFunctorVisitor``

    :raises RuntimeError: if no candidate operation remains after filtering
    """
    self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)

    # The epilogue functor may consume too much shared memory, but that is only a concern
    # for the sm90+ epilogue. In sm80, the epilogue and mainloop share the shared memory.
    if self.cc not in [90, 100, 101, 103]:
        return

    candidates = self.possible_operations
    surviving = KernelsForDataType(candidates.datatype_comb, candidates.layout_comb)

    for candidate in candidates.all_operations:
        tile = datatypes.td_from_profiler_op(candidate)

        # Filter invalid epilogue schedules
        if cc_map[self.cc] == 90 and tile.epilogue_schedule not in [
                cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
                cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
            continue

        epilogue_bytes = self.epilogue_functor.get_smem_size(tile)

        # Work out how many mainloop stages fit next to the epilogue
        stage_bytes = check.calculate_smem_usage_per_stage(tile, OperationKind.Gemm)
        # Table entry scaled by 1024 (presumably KiB -> bytes; confirm against SharedMemPerCC)
        capacity_bytes = SharedMemPerCC[self.cc] << 10
        stages_that_fit = (capacity_bytes - epilogue_bytes) // stage_bytes
        if stages_that_fit < 2:
            # Mainloop stages must be >= 2
            continue

        surviving.add(candidate)

    if len(surviving.all_operations) == 0:
        raise RuntimeError(
            "The epilogue consumes too much shared memory. "
            "No valid tile description is found in the generator.")

    self.possible_operations = surviving
def run_setup(self):
    """
    Steps that must be taken before calling `plan.run()`
    """
    # Initialize the memory pool, if not already done
    cutlass_cppgen.get_memory_pool()