Rename python/cutlass to python/cutlass_cppgen (#2652)
This commit is contained in:
committed by
Haicheng Wu
parent
4260d4aef9
commit
177a82e251
36
python/cutlass_cppgen/op/__init__.py
Normal file
36
python/cutlass_cppgen/op/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
||||
from cutlass_cppgen.op.gemm import Gemm
|
||||
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
997
python/cutlass_cppgen/op/conv.py
Normal file
997
python/cutlass_cppgen/op/conv.py
Normal file
@ -0,0 +1,997 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Ease-of-use interface for constructing, compiling, and running CONVs
|
||||
|
||||
The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
|
||||
CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
|
||||
Under the hood, the interface will select sensible default parameters for the many template
|
||||
parameters for CUTLASS CONVs.
|
||||
|
||||
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
||||
performance, one should specify and tune each configuration parameter.
|
||||
|
||||
The simplest example of using this interface is the following:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# A, B, C, and D are torch/numpy/cupy tensor objects
|
||||
plan = cutlass_cppgen.op.Conv2d(A, B, C, D)
|
||||
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
One can also use the interface by specifying data types of operands at construction
|
||||
and using different tensor objects with these data types at runtime:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# The following is shorthand for:
|
||||
# cutlass_cppgen.op.Conv2d(kind="fprop",
|
||||
# element_A=torch.float32, element_B=torch.float32,
|
||||
# element_C=torch.float32, element_D=torch.float32,
|
||||
# element_accumulator=torch.float32)
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
|
||||
|
||||
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
||||
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
|
||||
C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
|
||||
D0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
|
||||
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
A1 = torch.rand((32, 128), dtype=torch.float32, device='cuda')
|
||||
B1 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
||||
C1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
|
||||
D1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
|
||||
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
||||
kernel from its execution:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
Elementwise activation functions are easily fused to the convolution via the interface:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
plan.activation = cutlass_cppgen.epilogue.relu
|
||||
|
||||
Operations can also be run asynchronously:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
args = plan.run()
|
||||
|
||||
# Do other work...
|
||||
|
||||
args.sync()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
from cutlass_library import (
|
||||
ConvKind,
|
||||
ConvMode,
|
||||
DataTypeSize,
|
||||
IteratorAlgorithm,
|
||||
OperationKind,
|
||||
SplitKMode,
|
||||
StrideSupport,
|
||||
)
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import epilogue
|
||||
from cutlass_cppgen.backend import compiler
|
||||
from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
|
||||
from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
|
||||
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
|
||||
from cutlass_cppgen.utils import check, datatypes
|
||||
|
||||
|
||||
class Conv2d(OperationBase):
|
||||
"""
|
||||
Constructs a ``Conv2d`` object.
|
||||
|
||||
The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C,
|
||||
along with the data type of output D and that used for accumulation, are bound to the ``Conv2d``
|
||||
object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
|
||||
|
||||
The constructor has optional parameters for flexibly setting these parameters. The following
|
||||
constructors are equivalent:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# Use F32 for A, B, C, D, and accumulation in fprop
|
||||
|
||||
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
|
||||
Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
|
||||
|
||||
# Explicitly specify the data types to use for A, B, C, and D.
|
||||
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
|
||||
element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
|
||||
|
||||
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
||||
# executing the convolution via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
||||
# have the same data type as those passed in here).
|
||||
# A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
|
||||
Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
|
||||
|
||||
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
|
||||
# those passed in via the generic ``element``
|
||||
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
|
||||
element=cutlass_cppgen.DataType.f32)
|
||||
|
||||
The order of precedence for the setting of the data type for a given operand/output is as follows:
|
||||
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
|
||||
2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
|
||||
3) Otherwise, use the generic values (e.g., ``element``)
|
||||
|
||||
:param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
|
||||
:type kind: str
|
||||
:param A: tensor representing data type of operand A
|
||||
:param B: tensor representing data type of operand B
|
||||
:param C: tensor representing data type of operand C
|
||||
:param D: tensor representing data type of operand D
|
||||
:param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
|
||||
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
||||
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
||||
:type element: cutlass_cppgen.DataType
|
||||
:param element_A: data type to be used for operand A
|
||||
:type element_A: cutlass_cppgen.DataType
|
||||
:param element_B: data type to be used for operand B
|
||||
:type element_B: cutlass_cppgen.DataType
|
||||
:param element_C: data type to be used for operand C
|
||||
:type element_C: cutlass_cppgen.DataType
|
||||
:param element_D: data type to be used for operand D
|
||||
:type element_D: cutlass_cppgen.DataType
|
||||
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
||||
:type element_accumulator: cutlass_cppgen.DataType
|
||||
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
||||
:type cc: int
|
||||
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
||||
:type kernel_cc: int
|
||||
"""
|
||||
def __init__(
    self, kind="fprop",
    A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
    element=None,
    element_A=None, element_B=None, element_C=None, element_D=None,
    element_accumulator=None,
    cc: int = None, kernel_cc: int = None
):
    """
    Binds the convolution kind and the operand/accumulator data types to this object
    and selects default kernel options.

    For each of A, B, C, and D the data type is taken from the tensor if one is given,
    otherwise from the per-operand ``element_*`` argument, otherwise from ``element``.
    The layout of every operand is set to ``TensorNHWC``. If ``element_accumulator`` is
    not given, the data type of C is used for accumulation.

    :raises Exception: if both ``element_X`` and tensor ``X`` are given for an operand, or if
        none of ``element_X``, tensor ``X``, and ``element`` is given
    """
    super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
    # Verify the kernel cc
    if self.current_cc in [90, 100, 101, 103]:
        # Conv2d kernels are not selected for these compute capabilities
        # (90, 100, 101, 103); revert to use SM80-tagged kernels
        cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
        self.specified_kernel_cc = 80
        self._reset_options(80)

    # The arch is used in testing
    self.arch = self.current_cc
    self.name = "conv2d" + kind

    # The convolution kind. (concept: cutlass_library.library.ConvKind)
    self.conv_kind = datatypes.getattr_enum(ConvKind, kind)

    # The element types (concept: cutlass library types) of A, B, C, and D
    elements = []
    layouts = []

    # Complete the data types based on user-provided arguments
    for elt, tens, name in zip([element_A, element_B, element_C, element_D],
                               [A, B, C, D],
                               ["A", "B", "C", "D"]):
        if elt is not None and tens is not None:
            raise Exception(f'Must not specify both element_{name} and tensor {name}')
        if elt is None and tens is None and element is None:
            raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')

        # Precedence: tensor, then per-operand element, then the generic element
        if tens is not None:
            elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
        else:
            elt_to_set = elt if elt is not None else element

        assert elt_to_set is not None

        # Currently we only support layout TensorNHWC
        elements.append(datatypes.library_type(elt_to_set))
        layouts.append(cutlass_cppgen.LayoutType.TensorNHWC)

    self._element_a, self._element_b, self._element_c, self._element_d = elements
    self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

    if element_accumulator is None:
        self._element_accumulator = self._element_c
    else:
        self._element_accumulator = datatypes.library_type(element_accumulator)

    # Default inputs if none is supplied in run()
    self.A = A
    self.B = B
    self.C = C
    self.D = D

    self.alpha = alpha
    self.beta = beta

    # We only specify the stride of the swizzling functor here
    # The actual swizzling functor is determined in run based on conv_kind and stride
    self._swizzling_stride = 1

    # Arguments that will be set to default value in _reset_operations
    # The default tile_description and op_class are fetched from manifest of cutlass library
    self._tile_description = None
    self.op_class = None
    # The default identity epilogue will be created
    self.epilogue_functor = None

    self._reset_operations()

    # Arguments that will be determined online based on arguments of "run"
    # based on stride, input/output channels, alignment, and conv_kind
    self._iterator_algorithm = None
    self._stride_support = None
|
||||
|
||||
def _reset_operations(self, reset_epilogue: bool = True):
    """
    Re-derives the default opcode class and the preferred operand alignments from the
    currently bound data types and layouts.

    TensorOp is preferred when it is among the supported opcode classes; Simt is used
    otherwise.

    :param reset_epilogue: whether to reset the epilogue activation to ``epilogue.identity``
    :type reset_epilogue: bool

    :raises Exception: if neither TensorOp nor Simt is supported for the bound types and layouts
    """
    datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
    layout_comb = (self._layout_a, self._layout_b)

    self.possible_op_classes = self.options.supporting_opclasses(
        *datatype_comb, *layout_comb, self._math_operation)

    opcode_classes = cutlass_cppgen.OpcodeClass
    if opcode_classes.TensorOp in self.possible_op_classes:
        self.opclass = opcode_classes.TensorOp
    elif opcode_classes.Simt in self.possible_op_classes:
        self.opclass = opcode_classes.Simt
    else:
        math_op_str = (
            '' if self._math_operation is None
            else f' and math operation {self._math_operation}'
        )
        raise Exception(f'No kernel configuration found for supported data type and layout '
                        f'combination {datatype_comb}x{layout_comb}{math_op_str}')

    if reset_epilogue:
        self._reset_epilogue_functor_activation(epilogue.identity)

    # Preferred alignment: the largest alignment offered by the candidate operations,
    # capped at 128 // DataTypeSize[element]
    # (presumably the element count of a 128-bit access -- TODO confirm the unit of DataTypeSize)
    for operand, element_type in (("A", self._element_a),
                                  ("B", self._element_b),
                                  ("C", self._element_c)):
        preferred = min(
            128 // DataTypeSize[element_type],
            max(self.possible_operations.alignments(operand)))
        setattr(self, f"alignment_pref_{operand}", preferred)
|
||||
|
||||
#
|
||||
# Tile description Related
|
||||
#
|
||||
|
||||
@property
def tile_description(self) -> TileDescription:
    """
    Returns the tile description currently stored on this object (``None`` until one is set)
    """
    return self._tile_description

@tile_description.setter
def tile_description(
    self, td=None):
    """
    Set the tile description

    Passing ``None`` leaves the current tile description untouched. When a dict is passed,
    it is applied as an update on top of the current tile description (or, if none is set
    yet, on top of the one derived from the default operation). The dict passed by the
    caller is not modified.

    :param td: tile description
    :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
              {
                  "threadblock_shape": [int, int, int],
                  "warp_count": [int, int, int],
                  "stages": int,
                  "instruction_shape": [int, int, int] (optional),
                  "cluster_shape": [int, int, int] (optional)
              }

    :raises Exception: if the resulting tile description fails ``_valid_tile_description``
    """
    if td is None:
        return
    if isinstance(td, dict):
        if self._tile_description is None:
            op = self.possible_operations.default_operation(self._math_operation)
            self._tile_description = datatypes.td_from_profiler_op(op)

        # Work on a shallow copy so that overriding "cluster_shape" below does not
        # leak back into the dictionary owned by the caller
        td = dict(td)
        if "cluster_shape" in td and td["cluster_shape"] != [1, 1, 1]:
            cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
            td["cluster_shape"] = [1, 1, 1]
        td = self._tile_description.clone_and_update(td)

    valid, msg = self._valid_tile_description(td)
    if not valid:
        raise Exception(msg)
    self._tile_description = td
|
||||
|
||||
def _valid_tile_description(self, td: TileDescription) -> tuple:
    """
    Checks whether the provided tile description is valid for the current compute capability.

    Two checks are run in order, stopping at the first failure:

    - ``check.valid_stage_count`` for the stage count of ``td``
    - ``check.valid_cluster_shape`` for the cluster shape of ``td``

    What exactly each helper verifies is defined in ``cutlass_cppgen.utils.check``.

    :param td: tile description to validate
    :type td: cutlass_cppgen.backend.TileDescription
    :return: tuple in which the first element is a bool indicating that the tile description is valid
             and the second element is a string providing an optional error message.
    :rtype: tuple
    """
    stages_ok, stages_msg = check.valid_stage_count(self.cc, self.current_cc, td)
    if not stages_ok:
        return (stages_ok, stages_msg)

    # The outcome of the last check is the overall outcome, whether it passed or failed
    cluster_ok, cluster_msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
    return (cluster_ok, cluster_msg)
|
||||
|
||||
def tile_descriptions(self) -> list:
    """
    Returns the distinct tile descriptions of the candidate operations

    Tile descriptions are de-duplicated by their string form, keeping the first occurrence.
    If a math operation is set on this object, only tile descriptions whose math
    instruction uses that math operation are returned.

    :returns: list of valid tile descriptions for the operations
    :rtype: list
    """
    seen_strs = set()
    unique_tds = []
    for candidate_op in self.possible_operations.all_operations:
        candidate = datatypes.td_from_profiler_op(candidate_op)

        wrong_math_op = (
            self._math_operation is not None
            and candidate.math_instruction.math_operation != self._math_operation
        )
        if wrong_math_op:
            continue

        key = str(candidate)
        if key in seen_strs:
            continue
        seen_strs.add(key)
        unique_tds.append(candidate)
    return unique_tds
|
||||
|
||||
#
|
||||
# Swizzling functor Related
|
||||
#
|
||||
|
||||
@property
def swizzling_stride(self):
    """
    Returns the swizzling stride currently configured for the Conv2d

    :return: swizzling stride
    """
    return self._swizzling_stride

@swizzling_stride.setter
def swizzling_stride(self, stride: int):
    """
    Sets the swizzling stride

    Only the type of ``stride`` is checked here. The error message lists 1, 2, 4, and 8
    as the expected values, but membership in that set is not enforced by this setter.

    :param stride: the swizzling stride
    :type stride: int

    :raises Exception: if ``stride`` is not an ``int``
    """
    if isinstance(stride, int):
        self._swizzling_stride = stride
    else:
        raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
|
||||
|
||||
def _propose_swizzling_functor(self, stride):
    """
    Picks the swizzling functor from ``cutlass_cppgen.swizzle`` for the given stride

    A dgrad convolution with a non-unit stride in either dimension gets
    ``StridedDgradIdentitySwizzle<N>``; every other case gets ``IdentitySwizzle<N>``,
    where ``<N>`` is the configured swizzling stride.

    :param stride: convolution stride; elements 0 and 1 are inspected (only for dgrad)
    """
    strided_dgrad = (
        self.conv_kind == ConvKind.Dgrad
        and (stride[0] != 1 or stride[1] != 1)
    )
    functor_family = "StridedDgradIdentitySwizzle" if strided_dgrad else "IdentitySwizzle"
    return getattr(cutlass_cppgen.swizzle, f"{functor_family}{self._swizzling_stride}")
|
||||
|
||||
#
|
||||
# Iterator Algorithm Related
|
||||
#
|
||||
|
||||
@property
def iterator_algorithm(self) -> IteratorAlgorithm:
    """
    Returns the iterator algorithm currently stored on this object
    """
    return self._iterator_algorithm

@iterator_algorithm.setter
def iterator_algorithm(self, alg: str):
    """
    Sets the iterator algorithm

    :param alg: The iterator algorithm
    :type alg: string, options: "analytic", "optimized", "few_channels", and "fixed_channels"

    :raises Exception: if a few-channels or fixed-channels iterator is requested for a
        convolution kind other than fprop
    """
    requested = datatypes.getattr_enum(IteratorAlgorithm, alg)

    # The channel-specialized iterators are accepted for fprop only
    fprop_only = (IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels)
    if requested in fprop_only and self.conv_kind != ConvKind.Fprop:
        raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")

    self._iterator_algorithm = requested
|
||||
|
||||
def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
    """
    Propose a valid iterator algorithm based on problem size and alignment

    - fprop: fixed-channels if ``C == alignment_a``; optimized if ``C`` is a multiple of
      ``alignment_a`` and both ``R`` and ``S`` are at most 32; analytic otherwise
    - dgrad: optimized if ``K`` is a multiple of ``alignment_a``, ``R`` and ``S`` are at
      most 32, and ``C`` is a multiple of ``alignment_b``; analytic otherwise
    - wgrad: optimized if ``K`` is a multiple of ``alignment_a`` and ``C`` is a multiple
      of ``alignment_b``; analytic otherwise

    Returns ``None`` for any other convolution kind.
    """
    kind = self.conv_kind

    if kind == ConvKind.Fprop:
        # Check whether the fixed channel is applicable
        if problem_size.C == alignment_a:
            return IteratorAlgorithm.FixedChannels
        use_optimized = (problem_size.C % alignment_a == 0 and
                         problem_size.R <= 32 and problem_size.S <= 32)
    elif kind == ConvKind.Dgrad:
        use_optimized = (problem_size.K % alignment_a == 0 and
                         problem_size.R <= 32 and problem_size.S <= 32 and
                         problem_size.C % alignment_b == 0)
    elif kind == ConvKind.Wgrad:
        use_optimized = (problem_size.K % alignment_a == 0 and
                         problem_size.C % alignment_b == 0)
    else:
        return None

    return IteratorAlgorithm.Optimized if use_optimized else IteratorAlgorithm.Analytic
|
||||
|
||||
def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
    """
    Validate whether the user-provided iterator algorithm works for the given problem size

    Only the combinations listed below carry a constraint; every other combination of
    convolution kind and iterator algorithm is accepted.

    - fprop / fixed-channels: ``C == alignment_a``
    - fprop / optimized: ``C`` is a multiple of ``alignment_a`` and ``R``, ``S`` are at most 32
    - fprop / few-channels: ``C`` is a multiple of ``alignment_a``
    - dgrad / optimized: ``K`` is a multiple of ``alignment_a``, ``R``, ``S`` are at most 32,
      and ``C`` is a multiple of ``alignment_b``
    - wgrad / optimized: ``K`` is a multiple of ``alignment_a`` and ``C`` is a multiple
      of ``alignment_b``
    """
    kind = self.conv_kind

    if kind == ConvKind.Fprop:
        if iterator_algorithm == IteratorAlgorithm.FixedChannels:
            return problem_size.C == alignment_a
        if iterator_algorithm == IteratorAlgorithm.Optimized:
            return (problem_size.C % alignment_a == 0 and
                    problem_size.R <= 32 and problem_size.S <= 32)
        if iterator_algorithm == IteratorAlgorithm.FewChannels:
            return problem_size.C % alignment_a == 0
        return True

    if kind == ConvKind.Dgrad:
        if iterator_algorithm == IteratorAlgorithm.Optimized:
            return (problem_size.K % alignment_a == 0 and
                    problem_size.R <= 32 and problem_size.S <= 32 and
                    problem_size.C % alignment_b == 0)
        return True

    if kind == ConvKind.Wgrad:
        if iterator_algorithm == IteratorAlgorithm.Optimized:
            return (problem_size.K % alignment_a == 0 and
                    problem_size.C % alignment_b == 0)

    return True
|
||||
|
||||
#
|
||||
# Stride Support Related
|
||||
#
|
||||
|
||||
def _propose_stride_support(self, stride):
    """
    Proposes the stride support for the given stride

    ``StrideSupport.Unity`` is returned only for a dgrad convolution whose stride is 1 in
    both dimensions; ``StrideSupport.Strided`` is returned in every other case.

    :param stride: convolution stride; elements 0 and 1 are inspected (only for dgrad)
    """
    unit_stride_dgrad = (
        self.conv_kind == ConvKind.Dgrad
        and stride[0] == 1
        and stride[1] == 1
    )
    return StrideSupport.Unity if unit_stride_dgrad else StrideSupport.Strided
|
||||
|
||||
#
|
||||
# Construct and Compilation
|
||||
#
|
||||
|
||||
def construct(
    self, tile_description: TileDescription = None,
    alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
    iterator_algorithm: IteratorAlgorithm = None,
    stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
    epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
    """
    Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
    kernel specification of the ``Conv2d`` object.

    Every argument left as ``None`` falls back to the value stored on this object, and
    then to a default: the first candidate operation matching the alignments for the tile
    description, the analytic iterator, strided stride support, and the swizzling functor
    proposed for a stride of (2, 2).

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass_cppgen.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param iterator_algorithm: the iterator algorithm used
    :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
    :param stride_support: the stride support of dgrad
    :type stride_support: cutlass_library.library.StrideSupport
    :param swizzling_functor: the swizzling functor
    :type swizzling_functor: cutlass_cppgen.swizzle
    :param epilogue_functor: the epilogue functor

    :return: operation that was constructed
    :rtype: cutlass_cppgen.backend.Conv2dOperation

    :raises Exception: if an explicitly passed ``tile_description`` is invalid
    """
    # Get alignment
    alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
    alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
    alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)

    # Each operand is described with its own layout
    tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
    tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
    tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

    if tile_description is None:
        if self.tile_description is not None:
            tile_description = self.tile_description
        else:
            op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
            tile_description = datatypes.td_from_profiler_op(op)
    else:
        valid, err_str = self._valid_tile_description(tile_description)
        if not valid:
            raise Exception(f"Invalid tile description. {err_str}")
        # An explicitly passed tile description also becomes the stored one
        self.tile_description = tile_description

    if iterator_algorithm is None:
        # If the iterator algorithm is already set
        if self.iterator_algorithm is not None:
            iterator_algorithm = self.iterator_algorithm
        else:
            # Otherwise, we conservatively use the analytic iterator for correctness
            iterator_algorithm = IteratorAlgorithm.Analytic

    if stride_support is None:
        # If the stride support is already set
        if self._stride_support is not None:
            stride_support = self._stride_support
        else:
            # Otherwise, we assume strided
            stride_support = StrideSupport.Strided

    if swizzling_functor is None:
        # No functor was passed in: propose one for a non-unit stride, so that dgrad
        # gets the strided-dgrad variant
        swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))

    if epilogue_functor is None:
        if self.epilogue_functor is not None:
            epilogue_functor = self.epilogue_functor
        else:
            epilogue_functor = self._create_epilogue_functor_activation(self._activation)

    # Reset the alignment of the epilogue functor
    epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)

    operation = Conv2dOperation(
        conv_kind=self.conv_kind,
        iterator_algorithm=iterator_algorithm,
        arch=self.current_cc,
        tile_description=tile_description,
        A=tensor_A, B=tensor_B, C=tensor_C,
        stride_support=stride_support,
        epilogue_functor=epilogue_functor,
        swizzling_functor=swizzling_functor,
    )

    return operation
|
||||
|
||||
def compile(self, tile_description: TileDescription = None,
            alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
            iterator_algorithm: IteratorAlgorithm = None,
            stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
            epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
    """
    Emits and compiles the kernel currently specified. If ``tile_description`` and any
    of the ``alignment`` parameters are set, the kernel will be chosen using this
    tile description and alignments. Otherwise, a default tile description and alignment
    will be used.

    The constructed operation is stored in ``self.operation`` before it is compiled.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass_cppgen.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param iterator_algorithm: the iterator algorithm used
    :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
    :param stride_support: the stride support of dgrad
    :type stride_support: cutlass_library.library.StrideSupport
    :param swizzling_functor: the swizzling functor
    :type swizzling_functor: cutlass_cppgen.swizzle
    :param epilogue_functor: the epilogue functor
    :param print_module: whether to print the emitted module of the operation before compiling
    :type print_module: bool

    :return: operation that was compiled
    :rtype: cutlass_cppgen.backend.Conv2dOperation
    """
    constructed = self.construct(
        tile_description=tile_description,
        alignment_A=alignment_A,
        alignment_B=alignment_B,
        alignment_C=alignment_C,
        iterator_algorithm=iterator_algorithm,
        stride_support=stride_support,
        swizzling_functor=swizzling_functor,
        epilogue_functor=epilogue_functor,
    )
    self.operation = constructed

    if print_module:
        print(self.operation.rt_module.emit())

    compiler.add_module([self.operation])
    return self.operation
|
||||
|
||||
#
|
||||
# Run Related
|
||||
#
|
||||
|
||||
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
    """
    Check that ``tensor`` has data type ``ref_type``; raise if it does not.

    Despite the method name, only the data type is compared: the layout reported for
    ``tensor`` is discarded and ``ref_layout`` is accepted but never consulted.

    :param tensor: object representing a tensor passed in to verify
    :type tensor: numpy/cupy/torch array/tensor object
    :param ref_type: data type that the tensor is expected to have
    :param ref_layout: expected layout (currently unused)
    :param name: identifier of the tensor to verify. Used in raising exceptions
    :type name: str

    :raises Exception: if the data type of ``tensor`` differs from ``ref_type``
    """
    dtype, _ignored_layout = datatypes.get_datatype_and_layout(tensor)
    type_mismatch = dtype != ref_type
    if type_mismatch:
        raise Exception(f'Tensor {name} with type and layout {dtype} '
                        f'does not match the expected type of {ref_type}.')
|
||||
|
||||
def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
    """
    Build a ``Conv2DProblemSize`` from the operand shapes and verify the output extent.

    The operands ``A``, ``B`` and ``C`` are first mapped onto the roles activation, filter
    and output according to ``self.conv_kind``. The shapes of those tensors, together with
    ``padding``, ``stride`` and ``dilation``, are used to create the problem size. The
    second and third shape entries of the output tensor are then compared against the
    ``P`` and ``Q`` computed by the problem size.

    :param A: operand A of the convolution
    :param B: operand B of the convolution
    :param C: operand C of the convolution
    :param stride: indexable whose entries 0 and 1 are passed as the stride
    :param padding: indexable whose entries 0 and 1 are passed as the padding
    :param dilation: indexable whose entries 0 and 1 are passed as the dilation

    :return: the problem size derived from the operands
    :rtype: Conv2DProblemSize

    :raises Exception: if the convolution kind is not Fprop, Dgrad or Wgrad, or if the
        output tensor's extent disagrees with the computed ``P``/``Q``
    """
    # Role of each operand per convolution kind; ``output_tensor`` is the operand
    # letter reported in the size-mismatch message below.
    if self.conv_kind == ConvKind.Fprop:
        activation, filter_, result, output_tensor = A, B, C, "C"
    elif self.conv_kind == ConvKind.Dgrad:
        result, filter_, activation, output_tensor = A, B, C, "A"
    elif self.conv_kind == ConvKind.Wgrad:
        result, activation, filter_, output_tensor = A, B, C, "A"
    else:
        raise Exception(f"Convolution kind {self.conv_kind} is not supported")

    N_, H_, W_, C_ = datatypes.get_tensor_shape(activation, op="CONV")
    K_, R_, S_, _ = datatypes.get_tensor_shape(filter_, op="CONV")
    _, P_, Q_, _ = datatypes.get_tensor_shape(result, op="CONV")

    # The channel count of the activation is reused for the filter; the filter's own
    # last shape entry is not consulted.
    problem_size = Conv2DProblemSize(
        N_, H_, W_, C_,
        K_, R_, S_, C_,
        padding[0], padding[1],
        stride[0], stride[1],
        dilation[0], dilation[1],
        ConvMode.CrossCorrelation,
        1, 1
    )

    extent_mismatch = P_ != problem_size.P or Q_ != problem_size.Q
    if extent_mismatch:
        raise Exception(
            f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")

    return problem_size
|
||||
|
||||
def run(self, A=None, B=None, C=None, D=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1),
        alpha=None, beta=None,
        split_k=("serial", 1), sync: bool = True,
        print_module: bool = False,
        stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
    """
    Emit and compile the kernel via ``compile()`` and launch it.

    Operand tensors and the scalars ``alpha``/``beta`` come either from this call or from
    the values given when the object was constructed -- one of the two must be specified.

    By default the call returns only after the kernel has completed. With ``sync=False``
    the kernel is launched and the call returns immediately; the caller must then
    synchronize by calling ``sync()`` on the returned arguments before reading outputs.

    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
    :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
    :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param split_k: a tuple (split_k_mode, split_k_slices)
    :param sync: whether the call should wait for the kernel to complete before returning
    :type sync: bool
    :param print_module: whether to print the emitted C++ code
    :type print_module: bool
    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
    :type stream: :class:`cuda.cuda.CUstream`

    :return: arguments passed in to the kernel
    :rtype: cutlass_cppgen.backend.Conv2dArguments

    :raises Exception: if ``C`` is missing while ``beta`` is non-zero, or if the configured
        iterator algorithm is rejected for this problem
    """
    if not stream:
        stream = cuda.CUstream(0)
    super().run_setup()

    # Resolve every operand and scalar against the values stored on the object
    A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
    B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
    C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
    D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
    alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
    beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

    # Without a C operand, D stands in for it -- only allowed when beta is zero
    if C is None:
        if beta != 0:
            raise Exception(f"With beta {beta} != 0, C has to be provided.")
        C = D

    # Building the problem size also checks that the operands agree with
    # stride, padding and dilation
    problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)

    # Stride support and swizzling functor are both proposed from the stride
    stride_support = self._propose_stride_support(stride)
    swizzling_functor = self._propose_swizzling_functor(stride)

    shape_a = datatypes.get_tensor_shape(A, op="CONV")
    shape_b = datatypes.get_tensor_shape(B, op="CONV")
    shape_c = datatypes.get_tensor_shape(C, op="CONV")

    # Alignment found for each operand shape, then adjusted by the stored preference
    operations = self.possible_operations
    alignment_a = operations.find_alignment(shape_a, self._layout_a, operand="A")
    alignment_b = operations.find_alignment(shape_b, self._layout_b, operand="B")
    alignment_c = operations.find_alignment(shape_c, self._layout_c, operand="C")

    alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
    alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
    alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)

    # Use the user's iterator algorithm if it validates; propose one if none was set
    if self._iterator_algorithm is None:
        iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
    elif self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b):
        iterator_algorithm = self._iterator_algorithm
    else:
        raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")

    # Epilogue arguments: alpha, beta, then any activation arguments set on the object
    epilogue_args = [alpha, beta]
    if hasattr(self, "_activation_args"):
        activation_args = self._activation_args
        epilogue_args.extend(activation_args if isinstance(activation_args, list) else [activation_args])

    # Parallel split-k with more than one slice runs a separate reduction kernel;
    # the convolution itself then uses an identity-activation epilogue and the
    # object's own epilogue functor is applied in the reduction instead.
    parallel_split_k = split_k[0] == "parallel" and split_k[1] > 1
    if parallel_split_k:
        epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
    else:
        epilogue_functor = self.epilogue_functor

    self.compile(
        tile_description=self.tile_description,
        alignment_A=alignment_a,
        alignment_B=alignment_b,
        alignment_C=alignment_c,
        iterator_algorithm=iterator_algorithm,
        stride_support=stride_support,
        swizzling_functor=swizzling_functor,
        epilogue_functor=epilogue_functor,
        print_module=print_module,
    )

    if parallel_split_k:
        # Build and register the reduction kernel
        epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
        self.reduction_operation = ReductionOperation(
            shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
            element_accumulator=self._element_accumulator,
            element_compute=self._element_accumulator,
            epilogue_functor=epilogue_functor_reduction,
            count=alignment_c
        )
        if print_module:
            print(self.reduction_operation.rt_module.emit())
        compiler.add_module([self.reduction_operation])

    arguments = Conv2dArguments(
        operation=self.operation, problem_size=problem_size,
        A=A, B=B, C=C, D=D,
        output_op=self.operation.epilogue_type(*epilogue_args),
        split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
        split_k_slices=split_k[1],
        stream=stream
    )

    self.operation.run(arguments)

    if parallel_split_k:
        # The reduction takes the convolution's ``ptr_D`` as workspace and writes to ``D``
        implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
        reduction_arguments = ReductionArguments(
            self.reduction_operation,
            problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
            partitions=split_k[1],
            workspace=arguments.ptr_D,
            destination=D,
            source=C,
            output_op=self.reduction_operation.epilogue_type(*epilogue_args),
            stream=stream
        )
        self.reduction_operation.run(reduction_arguments)

    if sync:
        if parallel_split_k:
            reduction_arguments.sync()

            # ``arguments.sync()`` is not called on this path (it would free the
            # memory allocated by the arguments), so free it explicitly
            arguments.free()
        else:
            arguments.sync()

    return arguments
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
@staticmethod
def output_size(input_size, weight_size, padding, stride, dilation):
    """
    Return the ``(N, P, Q, K)`` fields of the problem size implied by the given sizes.

    ``input_size`` and ``weight_size`` are unpacked, in that order, as the leading
    arguments of ``Conv2DProblemSize``; entries 0 and 1 of ``padding``, ``stride`` and
    ``dilation`` follow. The problem is created in cross-correlation mode.

    :param input_size: iterable unpacked as the first problem-size arguments
    :param weight_size: iterable unpacked after ``input_size``
    :param padding: indexable whose entries 0 and 1 are used
    :param stride: indexable whose entries 0 and 1 are used
    :param dilation: indexable whose entries 0 and 1 are used

    :return: tuple ``(N, P, Q, K)`` read from the constructed problem size
    :rtype: tuple
    """
    conv_args = [
        *input_size,
        *weight_size,
        padding[0], padding[1],
        stride[0], stride[1],
        dilation[0], dilation[1],
        ConvMode.CrossCorrelation,
        1, 1,
    ]
    size = Conv2DProblemSize(*conv_args)
    return (size.N, size.P, size.Q, size.K)
|
||||
|
||||
|
||||
#
|
||||
# Easy to use interfaces for fprop, wgrad, and dgrad
|
||||
#
|
||||
|
||||
class Conv2dFprop(Conv2d):
    """
    ``Conv2d`` constructed with kind ``"fprop"``, exposing the operands under the names
    ``input``, ``weight`` and ``output`` in place of ``A``, ``B`` and ``D``.
    """

    def __init__(
        self,
        input=None, weight=None, C=None, output=None, alpha=1, beta=0,
        element=None,
        element_input=None, element_weight=None, element_C=None, element_output=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Operand mapping: input -> A, weight -> B, output -> D
        super().__init__(
            "fprop", input, weight, C, output, alpha, beta, element,
            element_input, element_weight, element_C, element_output,
            element_accumulator, cc, kernel_cc)

    def run(
        self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
        sync: bool = True, print_module: bool = False,
        stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """
        Run the forward convolution by delegating to ``Conv2d.run`` with
        ``input``, ``weight`` and ``output`` passed as ``A``, ``B`` and ``D``.

        :return: arguments passed in to the kernel
        """
        if not stream:
            stream = cuda.CUstream(0)

        # NOTE(review): the values are forwarded positionally in the order
        # (alpha, beta, stride, padding, dilation), while ``Conv2d.run`` in this file
        # declares (stride, padding, dilation, alpha, beta) after ``D``. As written,
        # ``alpha``/``beta`` given here land in the base ``stride``/``padding`` slots.
        # Forwarding by keyword would correct that, but would change results for
        # callers that pass these five values positionally in the base-class order --
        # confirm the intended contract before changing.
        return super().run(
            input, weight, C, output, alpha, beta, stride, padding, dilation,
            split_k, sync, print_module, stream)
|
||||
|
||||
|
||||
class Conv2dDgrad(Conv2d):
    """
    ``Conv2d`` constructed with kind ``"dgrad"``, exposing the operands under the names
    ``grad_output``, ``weight`` and ``grad_input`` in place of ``A``, ``B`` and ``D``.
    """

    def __init__(
        self,
        grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Operand mapping: grad_output -> A, weight -> B, grad_input -> D
        super().__init__(
            "dgrad", grad_output, weight, C, grad_input, alpha, beta, element,
            element_grad_output, element_weight, element_C, element_grad_input,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """
        Run the data-gradient convolution by delegating to ``Conv2d.run`` with
        ``grad_output``, ``weight`` and ``grad_input`` passed as ``A``, ``B`` and ``D``.

        :return: arguments passed in to the kernel
        """
        if not stream:
            stream = cuda.CUstream(0)

        # NOTE(review): the values are forwarded positionally in the order
        # (alpha, beta, stride, padding, dilation), while ``Conv2d.run`` in this file
        # declares (stride, padding, dilation, alpha, beta) after ``D``. As written,
        # ``alpha``/``beta`` given here land in the base ``stride``/``padding`` slots.
        # Forwarding by keyword would correct that, but would change results for
        # callers that pass these five values positionally in the base-class order --
        # confirm the intended contract before changing.
        return super().run(
            grad_output, weight, C, grad_input, alpha, beta, stride, padding, dilation,
            split_k, sync, print_module, stream)
|
||||
|
||||
|
||||
class Conv2dWgrad(Conv2d):
    """
    ``Conv2d`` constructed with kind ``"wgrad"``, exposing the operands under the names
    ``grad_output``, ``input`` and ``grad_weight`` in place of ``A``, ``B`` and ``D``.
    """

    def __init__(
        self,
        grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Operand mapping: grad_output -> A, input -> B, grad_weight -> D
        super().__init__(
            "wgrad", grad_output, input, C, grad_weight, alpha, beta, element,
            element_grad_output, element_input, element_C, element_grad_weight,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """
        Run the weight-gradient convolution by delegating to ``Conv2d.run`` with
        ``grad_output``, ``input`` and ``grad_weight`` passed as ``A``, ``B`` and ``D``.

        :return: arguments passed in to the kernel
        """
        if not stream:
            stream = cuda.CUstream(0)

        # NOTE(review): the values are forwarded positionally in the order
        # (alpha, beta, stride, padding, dilation), while ``Conv2d.run`` in this file
        # declares (stride, padding, dilation, alpha, beta) after ``D``. As written,
        # ``alpha``/``beta`` given here land in the base ``stride``/``padding`` slots.
        # Forwarding by keyword would correct that, but would change results for
        # callers that pass these five values positionally in the base-class order --
        # confirm the intended contract before changing.
        return super().run(
            grad_output, input, C, grad_weight, alpha, beta, stride, padding, dilation,
            split_k, sync, print_module, stream)
|
||||
725
python/cutlass_cppgen/op/gemm.py
Normal file
725
python/cutlass_cppgen/op/gemm.py
Normal file
@ -0,0 +1,725 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Ease-of-use interface for constructing, compiling, and running GEMMs.
|
||||
|
||||
The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
|
||||
GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
|
||||
Under the hood, the interface will select sensible default parameters for the many template
|
||||
parameters for CUTLASS GEMMs.
|
||||
|
||||
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
||||
performance, one should specify and tune each configuration parameter.
|
||||
|
||||
The simplest example of using this interface is the following:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# A, B, C, and D are torch/numpy/cupy tensor objects
|
||||
plan = cutlass_cppgen.op.Gemm(A, B, C, D)
|
||||
plan.run()
|
||||
|
||||
|
||||
One can also use the interface by specifying data types of operands at construction
|
||||
and using different tensor objects with these data types at runtime:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# The following is shorthand for:
|
||||
# cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
|
||||
# element_C=torch.float32, element_D=torch.float32,
|
||||
# element_accumulator=torch.float32,
|
||||
# layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
A0 = torch.rand((128, 256), device='cuda')
|
||||
B0 = torch.rand((256, 64), device='cuda')
|
||||
C0 = torch.zeros((128, 64), device='cuda')
|
||||
D0 = torch.zeros((128, 64), device='cuda')
|
||||
plan.run(A0, B0, C0, D0)
|
||||
|
||||
A1 = torch.rand((32, 128), device='cuda')
|
||||
B1 = torch.rand((128, 256), device='cuda')
|
||||
C1 = torch.zeros((32, 256), device='cuda')
|
||||
D1 = torch.zeros((32, 256), device='cuda')
|
||||
plan.run(A1, B1, C1, D1)
|
||||
|
||||
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
||||
kernel from its execution:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan.compile()
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A0, B0, C0, D0)
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A1, B1, C1, D1)
|
||||
|
||||
Elementwise activation functions are easily fused to the GEMM via the interface:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan.activation = cutlass_cppgen.epilogue.relu
|
||||
|
||||
Operations can also be run asynchronously:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
args = plan.run()
|
||||
|
||||
# Do other work...
|
||||
|
||||
args.sync()
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from math import prod
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
DataTypeSize,
|
||||
GemmUniversalMode,
|
||||
KernelScheduleSuffixes,
|
||||
)
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import epilogue, swizzle
|
||||
from cutlass_cppgen.backend import compiler
|
||||
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
from cutlass_cppgen.utils import check, datatypes
|
||||
|
||||
|
||||
class Gemm(OperationBase):
|
||||
"""
|
||||
Constructs a ``Gemm`` object.
|
||||
|
||||
The data types and layouts of operands A, B, and C, along with the data type of output D
|
||||
and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
|
||||
these are not to be changed after a ``Gemm`` has been constructed.
|
||||
|
||||
The constructor has optional parameters for flexibly setting these parameters. The following
|
||||
constructors are equivalent:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# Use F32 for A, B, C, D, and accumulation. All operands are row major.
|
||||
|
||||
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
|
||||
# for operands to the same values.
|
||||
Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
|
||||
Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
|
||||
element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
# Set the data types and layouts from existing tensors. Note that one can use different tensors when
|
||||
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
||||
# have the same data type and layout as those passed in here).
|
||||
# A, B, C, and D are row-major torch.Tensor objects of type torch.float32
|
||||
Gemm(A=A, B=B, C=C, D=D)
|
||||
|
||||
# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
|
||||
# the same as that for C, at present)
|
||||
Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
|
||||
layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
|
||||
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
|
||||
Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
|
||||
element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
|
||||
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
|
||||
2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
|
||||
3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
|
||||
|
||||
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
||||
:type cc: int
|
||||
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
||||
:type kernel_cc: int
|
||||
:param A: tensor representing data type and layout of operand A
|
||||
:param B: tensor representing data type and layout of operand B
|
||||
:param C: tensor representing data type and layout of operand C
|
||||
:param D: tensor representing data type and layout of operand D
|
||||
:param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
|
||||
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
||||
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
||||
:type element_accumulator: cutlass_cppgen.DataType
|
||||
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
||||
:type element: cutlass_cppgen.DataType
|
||||
:param layout: generic layout type to be used for operands A, B, C, and D
|
||||
:type layout: cutlass_cppgen.LayoutType
|
||||
:param element_A: data type to be used for operand A
|
||||
:type element_A: cutlass_cppgen.DataType
|
||||
:param element_B: data type to be used for operand B
|
||||
:type element_B: cutlass_cppgen.DataType
|
||||
:param element_C: data type to be used for operand C
|
||||
:type element_C: cutlass_cppgen.DataType
|
||||
:param element_D: data type to be used for operand D
|
||||
:type element_D: cutlass_cppgen.DataType
|
||||
:param layout_A: layout of operand A
|
||||
:type layout_A: cutlass_cppgen.LayoutType
|
||||
:param layout_B: layout of operand B
|
||||
:type layout_B: cutlass_cppgen.LayoutType
|
||||
:param layout_C: layout of operand C
|
||||
:type layout_C: cutlass_cppgen.LayoutType
|
||||
:param layout_D: layout of operand D
|
||||
:type layout_D: cutlass_cppgen.LayoutType
|
||||
"""
|
||||
|
||||
def __init__(
    self, A=None, B=None, C=None, D=None,
    alpha=1.0, beta=0.0, element_accumulator=None,
    element=None, layout=None,
    element_A=None, element_B=None, element_C=None, element_D=None,
    layout_A=None, layout_B=None, layout_C=None,
    cc: int = None, kernel_cc: int = None
):
    """
    Resolve the data type and layout of each operand and store the initial GEMM state.

    For each of A, B, C and D the type/layout comes from the tensor when one is given,
    otherwise from the operand-specific argument, otherwise from the generic ``element``
    / ``layout``. Operand D has no layout argument of its own and uses ``layout_C``.
    The accumulator type defaults to the resolved type of C.

    :raises Exception: if a tensor is given together with its operand-specific element
        or layout, or if no source for an operand's element or layout is available
    """
    super().__init__(cc=cc, kernel_cc=kernel_cc)
    self.name = "gemm"
    self.compiled = False

    # (operand-specific element, operand-specific layout, tensor, operand letter)
    operand_specs = (
        (element_A, layout_A, A, "A"),
        (element_B, layout_B, B, "B"),
        (element_C, layout_C, C, "C"),
        (element_D, layout_C, D, "D"),
    )

    elements = []
    layouts = []
    for elt, lay, tens, name in operand_specs:
        has_tensor = tens is not None

        # A tensor already carries type and layout, so explicit ones would conflict
        if elt is not None and has_tensor:
            raise Exception(f'Must not specify both element_{name} and tensor {name}')
        if lay is not None and has_tensor:
            raise Exception(f'Must not specify both layout_{name} and tensor {name}')

        # Without a tensor, some element and some layout source must exist
        if elt is None and not has_tensor and element is None:
            raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
        if lay is None and not has_tensor and layout is None:
            raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')

        if has_tensor:
            resolved_element, resolved_layout = datatypes.get_datatype_and_layout(tens)
        else:
            resolved_element = element if elt is None else elt
            resolved_layout = layout if lay is None else lay

        elements.append(datatypes.library_type(resolved_element))
        layouts.append(resolved_layout)

    self._element_a, self._element_b, self._element_c, self._element_d = elements
    self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

    if element_accumulator is None:
        self._element_accumulator = self._element_c
    else:
        self._element_accumulator = datatypes.library_type(element_accumulator)

    self.A, self.B, self.C, self.D = A, B, C, D
    self.alpha, self.beta = alpha, beta

    self.epilogue_functor = None
    self.op_class = None
    self._tile_description = None

    self._reset_operations()

    self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
|
||||
|
||||
def _reset_operations(self, reset_epilogue: bool = True):
    """
    Recompute the supported opcode classes and select the default one.

    ``self.possible_op_classes`` is refreshed from ``self.options`` for the current
    operand types, layouts and math operation. ``TensorOp`` is selected when available,
    otherwise ``Simt``.

    :param reset_epilogue: whether to also reset the epilogue functor to the identity activation
    :type reset_epilogue: bool

    :raises Exception: if neither ``TensorOp`` nor ``Simt`` is among the supported classes
    """
    datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
    layout_comb = (self._layout_a, self._layout_b)

    self.possible_op_classes = self.options.supporting_opclasses(
        *datatype_comb, *layout_comb, self._math_operation)

    # Candidates in order of preference
    preferred = (cutlass_cppgen.OpcodeClass.TensorOp, cutlass_cppgen.OpcodeClass.Simt)
    for candidate in preferred:
        if candidate in self.possible_op_classes:
            self.opclass = candidate
            break
    else:
        math_op_str = ''
        if self._math_operation is not None:
            math_op_str = f' and math operation {self._math_operation}'

        raise Exception(f'No kernel configuration found for supported data type and layout '
                        f'combination {datatype_comb}x{layout_comb}{math_op_str}')

    if reset_epilogue:
        self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
|
||||
|
||||
@property
def swizzling_functor(self):
    """
    Swizzling functor type currently used by the GEMM.

    :return: swizzling functor type
    """
    return self._swizzling_functor

@swizzling_functor.setter
def swizzling_functor(self, swizzling_functor):
    """
    Set the swizzling functor to the type given by ``swizzling_functor``.

    ``ThreadblockSwizzleStreamK`` is rejected when the opcode class is ``Simt`` or when
    the current compute capability is one of 90, 100, 101 or 103.

    :raises Exception: if ``ThreadblockSwizzleStreamK`` is requested in an unsupported configuration
    """
    wants_stream_k = swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK

    if wants_stream_k and self.op_class == cutlass_cppgen.OpcodeClass.Simt:
        raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')

    if wants_stream_k and self.current_cc in [90, 100, 101, 103]:
        raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+')

    self._swizzling_functor = swizzling_functor
|
||||
|
||||
#
|
||||
# Tile description Related
|
||||
#
|
||||
|
||||
@property
def tile_description(self) -> TileDescription:
    """
    Tile description currently set on the object (``None`` until one has been set).
    """
    return self._tile_description

@tile_description.setter
def tile_description(
    self, td=None):
    """
    Set the tile description after validating it.

    Assigning ``None`` is a no-op. A dict is applied as an update on top of the current
    tile description; if none is set yet, the one derived from the default operation is
    stored first and then used as the base for the update.

    :param td: tile description
    :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
              {
                  "threadblock_shape": [int, int, int],
                  "warp_count": [int, int, int],
                  "stages": int,
                  "instruction_shape": [int, int, int] (optional),
                  "cluster_shape": [int, int, int] (optional)
              }

    :raises Exception: with the validator's message if the tile description is invalid
    """
    if td is None:
        return

    if isinstance(td, dict):
        if self._tile_description is None:
            # No base to update yet: start from the default operation's description
            default_op = self.possible_operations.default_operation(self._math_operation)
            self._tile_description = datatypes.td_from_profiler_op(default_op)
        td = self._tile_description.clone_and_update(td)

    valid, msg = self._valid_tile_description(td)
    if not valid:
        raise Exception(msg)
    self._tile_description = td
|
||||
|
||||
def _valid_tile_description(self, td: TileDescription) -> tuple:
    """
    Validate a tile description against the compute capability in use.

    The checks run in this order, and the first failure is returned:

    - stage count (``check.valid_stage_count``)
    - cluster shape (``check.valid_cluster_shape``)
    - kernel schedule, epilogue schedule and tile scheduler (``check.valid_schedule``)

    After the schedule check, when ``self.cc`` is 100, 101 or 103, a kernel schedule is
    set and ``td.is_2sm`` is true, the first cluster dimension must be even; an odd value
    yields a failure regardless of the schedule check's result.

    :param td: tile description to validate
    :type td: cutlass_cppgen.backend.TileDescription
    :return: tuple in which the first element is a bool indicating that the tile description is valid
             and the second element is a string providing an optional error message.
    :rtype: tuple
    """
    valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
    if not valid:
        return (valid, msg)

    valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
    if not valid:
        return (valid, msg)

    valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)

    if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm:
        if td.cluster_shape[0] % 2 != 0:
            return False, "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103"

    return valid, msg
|
||||
|
||||
def tile_descriptions(self) -> list:
    """
    Lists the tile descriptions of every candidate operation of this plan.

    When a math operation has been selected (``self._math_operation`` is not ``None``),
    only tile descriptions whose math instruction uses that math operation are returned.

    :returns: list of tile descriptions for the candidate operations
    :rtype: list
    """
    candidates = list(map(datatypes.td_from_profiler_op, self.possible_operations.all_operations))
    if self._math_operation is None:
        return candidates
    return [
        candidate for candidate in candidates
        if candidate.math_instruction.math_operation == self._math_operation
    ]
|
||||
|
||||
def construct(
    self, tile_description: TileDescription = None,
    alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
    """
    Builds a ``cutlass_cppgen.backend.GemmOperationUniversal`` from the arguments and the
    current kernel specification of this ``Gemm`` object.

    The tile description is chosen as follows:

    - if ``tile_description`` is given, it is validated and stored on the plan
    - otherwise the tile description already stored on the plan is used, if any
    - otherwise the first candidate operation matching the alignments and math operation is
      used, and the alignment of C is reset to that operation's C alignment

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass_cppgen.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int

    :return: operation that was constructed
    :rtype: cutlass_cppgen.backend.GemmOperationUniversal
    :raises Exception: if ``tile_description`` is given and fails validation
    """
    # Preferred alignment: the largest alignment among candidate kernels, capped at
    # 128 // DataTypeSize[element].
    preferred_a = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
    preferred_b = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
    alignment_A = check.alignment_or_default(alignment_A, preferred_a)
    alignment_B = check.alignment_or_default(alignment_B, preferred_b)

    tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
    tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)

    if alignment_C is None:
        alignment_C = max(self.possible_operations.alignments("C"))
        # A void C has no element size to cap against.
        if self._element_c != DataType.void:
            alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)

    if tile_description is not None:
        is_valid, reason = self._valid_tile_description(tile_description)
        if not is_valid:
            raise Exception(f"Invalid tile description. {reason}")
        self._tile_description = tile_description
    elif self._tile_description is not None:
        tile_description = self._tile_description
    else:
        default_op = self.possible_operations.operations(
            alignment_A, alignment_B, alignment_C, self._math_operation)[0]
        tile_description = datatypes.td_from_profiler_op(default_op)

        # The selected op may have lower alignment than that determined above, so the
        # alignment of C must follow the op.
        alignment_C = default_op.C.alignment

    tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
    self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

    return GemmOperationUniversal(
        arch=self.current_cc,
        tile_description=tile_description,
        A=tensor_A, B=tensor_B, C=tensor_C,
        epilogue_functor=self.epilogue_functor,
        swizzling_functor=self._swizzling_functor,
    )
|
||||
|
||||
def compile(self, tile_description: TileDescription = None,
            alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
            print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
    """
    Constructs the kernel currently specified and registers it with the compiler.

    All of ``tile_description`` and the ``alignment`` parameters are forwarded to
    ``construct``, which falls back to defaults for any that are ``None``. The resulting
    operation is stored in ``self.operation``.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass_cppgen.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param print_module: whether to print the emitted C++ code
    :type print_module: bool

    :return: operation that was compiled
    :rtype: cutlass_cppgen.backend.GemmOperationUniversal
    """
    self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)

    if print_module:
        emitted_source = self.operation.rt_module.emit()
        print(emitted_source)

    modules = [self.operation]
    compiler.add_module(modules)
    return self.operation
|
||||
|
||||
def _verify_rank(self, tensor):
    """
    Checks that ``tensor`` has at least two dimensions.

    :param tensor: tensor to check; its ``shape`` attribute is inspected
    :type tensor: numpy/cupy/torch array/tensor object

    :raises Exception: if ``tensor`` has fewer than two dimensions
    """
    rank = len(tensor.shape)
    if rank >= 2:
        return
    raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
|
||||
|
||||
def _get_batch_count(self, A, B, C, D) -> int:
    """
    Derives the batch count from operands A and B.

    The batch count of an operand is the product of all of its dimensions other than the
    last two, or 1 if it has at most two dimensions. Only A and B are inspected; ``C`` and
    ``D`` are accepted but not examined here. If neither A nor B has a batch count of 1,
    the two must agree.

    :param A: tensor A
    :type A: numpy/cupy/torch array/tensor object
    :param B: tensor B
    :type B: numpy/cupy/torch array/tensor object
    :param C: tensor C (unused)
    :type C: numpy/cupy/torch array/tensor object
    :param D: tensor D (unused)
    :type D: numpy/cupy/torch array/tensor object

    :return: the larger of the batch counts of A and B
    :rtype: int
    :raises Exception: if A and B both have batch counts other than 1 and they differ
    """
    def leading_extent(operand):
        # Product of every mode in front of the trailing (rows, cols) pair.
        return prod(operand.shape[:-2]) if len(operand.shape) > 2 else 1

    count_a = leading_extent(A)
    count_b = leading_extent(B)

    both_batched = count_a != 1 and count_b != 1
    if both_batched and count_a != count_b:
        raise Exception(f"Get invalid batch counts: A={count_a}, B={count_b}")
    return max(count_a, count_b)
|
||||
|
||||
def _get_batch_stride(self, tensor) -> int:
    """
    Computes the number of elements separating consecutive matrices of ``tensor``.

    The stride is the product of the last two dimensions of ``tensor``. It is 0 when
    ``tensor`` is ``None`` or has at most two dimensions.

    :param tensor: tensor object to process, or ``None``
    :type tensor: numpy/cupy/torch array/tensor object

    :return: stride between each matrix in the batch
    :rtype: int
    """
    if tensor is None:
        return 0

    dims = tensor.shape
    if len(dims) <= 2:
        return 0

    rows, cols = dims[-2], dims[-1]
    return rows * cols
|
||||
|
||||
def _get_problem_args(self, A, B, C, D) -> tuple:
    """
    Determines the problem size, GEMM universal mode, and batch count for the operands.

    M and K are taken from the last two dimensions of ``A`` and N from the last dimension
    of ``B``. When the batch count exceeds 1, the batch is folded into a single GEMM where
    possible:

    - only A batched (C absent or batched), A and C row major: (b, m, k) x (k, n) -> (m*b, k) x (k, n)
    - only B batched (C absent or batched), B and C not row major: the batch is folded into N

    In both folded cases the mode stays ``GemmUniversalMode.Gemm`` and a batch count of 1
    is reported. Otherwise the mode is ``GemmUniversalMode.Batched``.

    :param A: tensor A
    :type A: numpy/cupy/torch array/tensor object
    :param B: tensor B
    :type B: numpy/cupy/torch array/tensor object
    :param C: tensor C
    :type C: numpy/cupy/torch array/tensor object
    :param D: tensor D
    :type D: numpy/cupy/torch array/tensor object

    :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
    :rtype: tuple
    """
    M, K = A.shape[-2:]
    N = B.shape[-1]

    batch_count = self._get_batch_count(A, B, C, D)
    if batch_count <= 1:
        return GemmCoord(M, N, K), GemmUniversalMode.Gemm, batch_count

    a_row_major = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
    b_row_major = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
    c_row_major = self._layout_c == cutlass_cppgen.LayoutType.RowMajor

    def is_batched(operand):
        # A tensor counts as batched if its rank is > 2 and the product of the modes
        # beyond rank 2 equals the batch count determined above. ``None`` also counts.
        return operand is None or (len(operand.shape) > 2 and prod(operand.shape[:-2]) == batch_count)

    a_batched = is_batched(A)
    b_batched = is_batched(B)

    if a_batched and not b_batched and (C is None or is_batched(C)) and a_row_major and c_row_major:
        # Fold the batch dimension of A into M.
        return GemmCoord(M * batch_count, N, K), GemmUniversalMode.Gemm, 1

    if not a_batched and b_batched and (C is None or is_batched(C)) and not b_row_major and not c_row_major:
        # Fold the batch dimension of B into N.
        return GemmCoord(M, N * batch_count, K), GemmUniversalMode.Gemm, 1

    return GemmCoord(M, N, K), GemmUniversalMode.Batched, batch_count
|
||||
|
||||
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
    """
    Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``.

    If either differs, a transpose of the last two dimensions of ``tensor`` is attempted,
    and an exception is raised only if that transpose itself fails.

    :param tensor: object representing a tensor passed in to verify
    :type tensor: numpy/cupy/torch array/tensor object
    :param ref_type: data type for the tensor that this object was initialized to
    :param ref_layout: layout for the tensor that this object was initialized to
    :param name: identifier of the tensor to verify. Used in raising exceptions
    :type name: str

    :raises Exception: if the type or layout does not match and the transpose fails
    """
    dtype, layout = datatypes.get_datatype_and_layout(tensor)
    if dtype == ref_type and layout == ref_layout:
        return

    try:
        # Attempt to transpose the tensor to fit the desired layout.
        # NOTE(review): the transposed tensor is not returned or stored, so the caller's
        # tensor is unchanged, and the type/layout are not re-checked afterwards --
        # confirm whether callers are expected to receive the transposed tensor.
        tensor.transpose(-1, -2)
    except Exception as err:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt/SystemExit propagate, and chain the original cause.
        raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
                        f'does not match the expected type and '
                        f'layout of ({ref_type}, {ref_layout}) and transpose failed.') from err
|
||||
|
||||
def run(self, A=None, B=None, C=None, D=None,
        alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
        stream: Optional[cuda.CUstream] = None) -> GemmArguments:
    """
    Compiles and runs the kernel currently specified.

    Operands and scalars are resolved through ``_verify_tensor`` / ``_verify_scalar`` from
    the ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` arguments of this call together
    with the values stored on this object.

    By default, this call returns only once the kernel has completed. To launch the kernel
    and immediately return, set ``sync=False``. In this case, it is the responsibility of the
    caller to synchronize the results of the kernel before attempting to access outputs
    by calling ``sync()`` on the arguments returned from this call.

    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param sync: whether the call should wait for the kernel to complete before returning
    :type sync: bool
    :param print_module: whether to print the emitted C++ code
    :type print_module: bool
    :param visitor_args: arguments handed to the epilogue when the epilogue functor is an
                         ``EpilogueFunctorVisitor``; ``alpha`` and ``beta`` are not passed to the epilogue in that case
    :type visitor_args: dict
    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
    :type stream: :class:`cuda.cuda.CUstream`

    :return: arguments passed in to the kernel
    :rtype: cutlass_cppgen.backend.GemmArguments
    """
    stream = stream or cuda.CUstream(0)
    super().run_setup()

    A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
    B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
    C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
    D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
    alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
    beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

    # The rank of C is not checked when the C element type is void.
    has_void_c = self._element_c == DataType.void
    for operand in ((A, B, D) if has_void_c else (A, B, C, D)):
        self._verify_rank(operand)

    align_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
    align_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")

    # Set C alignment based on D.shape so as to correctly get an alignment with void-C
    # kernels, for which `C` is None.
    align_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
    self.compile(self._tile_description, alignment_A=align_a, alignment_B=align_b,
                 alignment_C=align_c, print_module=print_module)

    problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)

    if mode != GemmUniversalMode.Gemm and batch_count != 1:
        launch_kwargs = {
            'batch': batch_count,
            'batch_strides': {
                label: self._get_batch_stride(operand)
                for label, operand in (('A', A), ('B', B), ('C', C), ('D', D))
            }
        }
    else:
        launch_kwargs = {'split_k_slices': 1}
    launch_kwargs['stream'] = stream

    uses_visitor = isinstance(self.epilogue_functor, EpilogueFunctorVisitor)
    epilogue_inputs = (visitor_args,) if uses_visitor else (alpha, beta)
    output_op = self.operation.epilogue_type(*epilogue_inputs)

    arguments = GemmArguments(
        operation=self.operation, problem_size=problem_size,
        A=A, B=B, C=C, D=D,
        output_op=output_op,
        gemm_mode=mode,
        **launch_kwargs
    )

    self.operation.run(arguments)

    if sync:
        arguments.sync()

    return arguments
|
||||
269
python/cutlass_cppgen/op/gemm_grouped.py
Normal file
269
python/cutlass_cppgen/op/gemm_grouped.py
Normal file
@ -0,0 +1,269 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Ease-of-use interface for constructing, compiling, and running GEMMs.
|
||||
|
||||
The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
|
||||
grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
|
||||
Under the hood, the interface will select sensible default parameters for the many template
|
||||
parameters for CUTLASS grouped GEMMs.
|
||||
|
||||
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
||||
performance, one should specify and tune each configuration parameter.
|
||||
|
||||
The simplest example of using this interface is the following:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
|
||||
plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from cutlass_library import DataTypeSize
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_cppgen.backend.gemm_operation import (
|
||||
GemmGroupedArguments,
|
||||
GemmOperationGrouped,
|
||||
)
|
||||
from cutlass_cppgen.backend.library import (
|
||||
SchedulerMode,
|
||||
TensorDescription,
|
||||
TileDescription,
|
||||
)
|
||||
from cutlass_cppgen.op.gemm import Gemm
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
from cutlass_cppgen.utils import check, datatypes
|
||||
|
||||
|
||||
class GroupedGemm(Gemm):
    """
    Constructs a ``GroupedGemm`` object.

    The data types and layouts of operands A, B, and C, along with the data type of output D
    and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
    these are not to be changed after a ``GroupedGemm`` has been constructed.

    The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
    for ``Gemm`` for examples of these.

    :param cc: compute capability of device to generate kernels for
    :type cc: int
    :param A: tensor representing data type and layout of operands A
    :param B: tensor representing data type and layout of operands B
    :param C: tensor representing data type and layout of operands C
    :param D: tensor representing data type and layout of operands D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param element_accumulator: data type to be used in accumulation of the product of operands A and B
    :type element_accumulator: cutlass_cppgen.DataType
    :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
    :type element: cutlass_cppgen.DataType
    :param layout: generic layout type to be used for operands A, B, C, and D
    :type layout: cutlass_cppgen.LayoutType
    :param element_A: data type to be used for operand A
    :type element_A: cutlass_cppgen.DataType
    :param element_B: data type to be used for operand B
    :type element_B: cutlass_cppgen.DataType
    :param element_C: data type to be used for operand C
    :type element_C: cutlass_cppgen.DataType
    :param element_D: data type to be used for operand D
    :type element_D: cutlass_cppgen.DataType
    :param layout_A: layout of operand A
    :type layout_A: cutlass_cppgen.LayoutType
    :param layout_B: layout of operand B
    :type layout_B: cutlass_cppgen.LayoutType
    :param layout_C: layout of operand C
    :type layout_C: cutlass_cppgen.LayoutType
    """

    def __init__(
        self, A=None, B=None, C=None, D=None,
        alpha=1.0, beta=0.0, element_accumulator=None,
        element=None, layout=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        layout_A=None, layout_B=None, layout_C=None,
        cc: int = None,
    ):
        super().__init__(
            A=A, B=B, C=C, D=D,
            alpha=alpha, beta=beta,
            element_accumulator=element_accumulator,
            element=element, layout=layout,
            element_A=element_A, element_B=element_B,
            element_C=element_C, element_D=element_D,
            layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
            cc=cc
        )

        # Grouped GEMM kernels are generated as SM80 kernels when the plan resolved to
        # compute capability 90, 100, 101, or 103: revert to the SM80 options. The
        # epilogue is left untouched by the operation reset.
        if self.current_cc in [90, 100, 101, 103]:
            self._reset_options(80)
            self._reset_operations(reset_epilogue=False)

        self.name = "grouped_gemm"

    @Gemm.swizzling_functor.setter
    def swizzling_functor(self, swizzling_functor):
        """
        Rejects any attempt to set the swizzling functor.

        :raises Exception: always, since grouped GEMM does not support different swizzling functors
        """
        raise Exception('Grouped GEMM does not currently support different swizzling functors')

    def construct(self, tile_description: TileDescription = None,
                  alignment_A: int = None,
                  alignment_B: int = None,
                  alignment_C: int = None) -> GemmOperationGrouped:
        """
        Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
        kernel specification of the ``GroupedGemm`` object.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int

        :return: operation that was constructed
        :rtype: cutlass_cppgen.backend.GemmOperationGrouped
        :raises Exception: if ``tile_description`` is given and fails validation
        """
        # Any alignment not supplied defaults to the largest one among candidate kernels.
        alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
        alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
        alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))

        self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

        # Each operand is described with its own layout (operand A previously used the
        # layout of operand B).
        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            # Fall back to the first candidate operation matching the alignments.
            op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
            tile_description = datatypes.td_from_profiler_op(op)
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self.tile_description = tile_description

        return GemmOperationGrouped(
            arch=self.current_cc,
            tile_description=tile_description,
            A=tensor_A, B=tensor_B, C=tensor_C,
            epilogue_functor=self.epilogue_functor,
            swizzling_functor=self._swizzling_functor,
            precompute_mode=SchedulerMode.Device)

    def run(self, A, B, C, D,
            alpha=None, beta=None, sync: bool = True,
            print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
        """
        Compiles and runs the kernel currently specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: list of tensors representing data type and layout of operand A
        :type A: list
        :param B: list of tensors representing data type and layout of operand B
        :type B: list
        :param C: list of tensors representing data type and layout of operand C
        :type C: list
        :param D: list of tensors representing data type and layout of operand D
        :type D: list
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool
        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
        :type stream: :class:`cuda.cuda.CUstream`

        :return: arguments passed in to the kernel
        :rtype: cutlass_cppgen.backend.GemmGroupedArguments
        :raises Exception: if the lists ``A``, ``B``, ``C``, and ``D`` differ in length
        """
        if not stream:
            stream = cuda.CUstream(0)

        super().run_setup()

        if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
            raise Exception("Lengths of A, B, C, and D lists must be equal")

        problem_sizes = []
        As, Bs, Cs, Ds = [], [], [], []
        for a, b, c, d in zip(A, B, C, D):
            As.append(self._verify_tensor(a, self.A, self._element_a, self._layout_a, "A"))
            Bs.append(self._verify_tensor(b, self.B, self._element_b, self._layout_b, "B"))
            Cs.append(self._verify_tensor(c, self.C, self._element_c, self._layout_c, "C"))
            Ds.append(self._verify_tensor(d, self.D, self._element_d, self._layout_d, "D"))
            # Problem size per group is (M, N, K), read from the tensors as passed in.
            problem_sizes.append(GemmCoord(a.shape[0], b.shape[1], a.shape[1]))

        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # A single kernel serves every group, so use the smallest alignment found across
        # the groups for each operand.
        alignment_a = min(
            self.possible_operations.find_alignment(t.shape, self._layout_a, operand="A") for t in As)
        alignment_b = min(
            self.possible_operations.find_alignment(t.shape, self._layout_b, operand="B") for t in Bs)
        alignment_c = min(
            self.possible_operations.find_alignment(t.shape, self._layout_c, operand="C") for t in Cs)
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        arguments = GemmGroupedArguments(
            operation=self.operation,
            problem_sizes=problem_sizes,
            A=As, B=Bs, C=Cs, D=Ds,
            output_op=self.operation.epilogue_type(alpha, beta),
            stream=stream
        )

        self.operation.run(arguments)

        if sync:
            arguments.sync()

        return arguments
|
||||
431
python/cutlass_cppgen/op/op.py
Normal file
431
python/cutlass_cppgen/op/op.py
Normal file
@ -0,0 +1,431 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
||||
"""
|
||||
|
||||
from bisect import bisect_left
|
||||
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
DataTypeSize,
|
||||
MathOperation,
|
||||
OperationKind,
|
||||
SharedMemPerCC
|
||||
)
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import get_option_registry
|
||||
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
from cutlass_cppgen.backend.utils.device import device_cc
|
||||
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
|
||||
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
|
||||
from cutlass_cppgen.swizzle import get_swizzling_functors
|
||||
from cutlass_cppgen.utils import datatypes, check
|
||||
|
||||
|
||||
class OperationBase:
|
||||
"""
|
||||
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
||||
"""
|
||||
|
||||
def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
    """
    :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90.
               When ``None``, the value returned by ``device_cc()`` is used
    :type cc: int
    :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80.
                      When ``None``, the closest generator compute capability at or below ``cc`` is used
    :type kernel_cc: int
    :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
    :type operation_kind: cutlass_library.OperationKind

    :raises Exception: if no options are registered for the resolved kernel compute capability
    """
    self.operation_kind = operation_kind
    self.cc = device_cc() if cc is None else cc

    kernel_cc_given = kernel_cc is not None
    self.specified_kernel_cc = kernel_cc_given
    if kernel_cc_given:
        self.current_cc = kernel_cc
    else:
        self.current_cc = self._find_closest_cc(self.cc)

    self.tile_description = None
    self._math_operation = None

    registry = get_option_registry()
    self.options = registry.options_for_cc(self.current_cc, operation_kind)
    if self.options is None:
        raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")

    # Start out with the identity activation function.
    self._activation = identity
|
||||
|
||||
def _find_closest_cc(self, cc: int) -> int:
|
||||
"""
|
||||
Returns the closest CC in _generator_ccs less than or equal to `cc`
|
||||
|
||||
:param cc: compute capability to query
|
||||
:type cc: int
|
||||
|
||||
:returns: closest CC in _generator_ccs less than or equal to `cc`
|
||||
:rtype: int
|
||||
"""
|
||||
if cc in _generator_ccs:
|
||||
return cc
|
||||
|
||||
# Find closest CC lower than this CC
|
||||
idx = bisect_left(_generator_ccs, cc)
|
||||
if idx == 0:
|
||||
raise Exception(f'No valid CC to fall back to for {cc}')
|
||||
return _generator_ccs[idx-1]
|
||||
|
||||
def activations(self) -> list:
|
||||
"""
|
||||
Returns possible activation functions that can be used
|
||||
|
||||
:return: list of activation functions that can be used
|
||||
:rtype: list
|
||||
"""
|
||||
return get_activations()
|
||||
|
||||
def swizzling_functors(self) -> list:
|
||||
"""
|
||||
Returns possible swizzling functions that can be used
|
||||
|
||||
:return: list of swizzling functions that can be used
|
||||
:rtype: list
|
||||
"""
|
||||
return get_swizzling_functors()
|
||||
|
||||
def _reset_options(self, cc: int):
|
||||
"""
|
||||
Resets the kernel options based on cc
|
||||
|
||||
:param cc: compute capability to reset to
|
||||
:type cc: int
|
||||
"""
|
||||
if cc != self.current_cc:
|
||||
if cc not in _generator_ccs:
|
||||
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
|
||||
self.current_cc = cc
|
||||
self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)
|
||||
|
||||
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
|
||||
"""
|
||||
Verifies the following properties:
|
||||
1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
|
||||
2) If ``scalar`` is not ``None``, its datatype must match matches the current version
|
||||
set by the plan (i.e., those in ``ref_dtype``)
|
||||
|
||||
If either of these properties does not hold, an exception is raised. If these properties hold and
|
||||
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
|
||||
|
||||
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
||||
:type scalar: numpy/cupy/torch scalar
|
||||
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
||||
:type ref_scalar: numpy/cupy/torch scalar
|
||||
:param ref_dtype: data type for the scalar that this object was initialized to
|
||||
:param name: identifier of the scalar to verify. Used in raising exceptions
|
||||
:type name: str
|
||||
|
||||
:return: valid scalar to use
|
||||
:rtype: numpy/cupy/torch scalar
|
||||
"""
|
||||
if scalar is None:
|
||||
if ref_scalar is None:
|
||||
raise Exception(f"Scalar {name} must be set.")
|
||||
return ref_scalar
|
||||
if hasattr(scalar, "dtype"):
|
||||
dtype = datatypes.library_type(scalar.dtype)
|
||||
if dtype != ref_dtype:
|
||||
raise Exception(
|
||||
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
|
||||
)
|
||||
return scalar
|
||||
|
||||
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
|
||||
"""
|
||||
Verifies the following properties:
|
||||
If ref_dtype is not void:
|
||||
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
|
||||
2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
|
||||
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
|
||||
If ref_dtype is void:
|
||||
Neither ``tensor`` nor ``ref_tensor`` are set
|
||||
|
||||
If either of these properties does not hold, an exception is raised. If these properties hold and
|
||||
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
|
||||
|
||||
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
||||
:type tensor: numpy/cupy/torch array/tensor object
|
||||
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
||||
:type ref_tensor: numpy/cupy/torch array/tensor object
|
||||
:param ref_dtype: data type for the tensor that this object was initialized to
|
||||
:param ref_layout: layout for the tensor that this object was initialized to
|
||||
:param name: identifier of the tensor to verify. Used in raising exceptions
|
||||
:type name: str
|
||||
|
||||
:return: valid tensor object to use
|
||||
:rtype: numpy/cupy/torch array/tensor object
|
||||
"""
|
||||
if ref_dtype == DataType.void:
|
||||
if tensor is not None or ref_tensor is not None:
|
||||
raise Exception("Operands with element DataType.void must not be provided a tensor")
|
||||
return None
|
||||
|
||||
if tensor is None:
|
||||
if ref_tensor is None:
|
||||
raise Exception(f"Tensor {name} must be set.")
|
||||
return ref_tensor
|
||||
|
||||
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def opclass(self) -> cutlass_cppgen.OpcodeClass:
|
||||
"""
|
||||
Returns the opcode class currently in use
|
||||
|
||||
:return: opcode class currently in use
|
||||
:rtype: cutlass_cppgen.OpcodeClass
|
||||
"""
|
||||
return self.op_class
|
||||
|
||||
@opclass.setter
|
||||
def opclass(self, oc: cutlass_cppgen.OpcodeClass):
|
||||
if isinstance(oc, str):
|
||||
oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
|
||||
if oc in self.possible_op_classes:
|
||||
self.op_class = oc
|
||||
else:
|
||||
raise Exception(
|
||||
f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
|
||||
f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
|
||||
f'layout combination ({self._layout_a}, {self._layout_b}).')
|
||||
|
||||
# Changing the op class also changes the possible operations available. Reset these.
|
||||
self.possible_operations = self.options.operations(
|
||||
self.op_class, self._element_a, self._element_b,
|
||||
self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
|
||||
|
||||
# Changing the op class changes the elements per access in the epilogue. Reset this.
|
||||
if self.epilogue_functor is not None:
|
||||
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
|
||||
|
||||
@property
|
||||
def math_operation(self) -> cutlass_cppgen.MathOperation:
|
||||
"""
|
||||
Returns the math operation currently in use
|
||||
|
||||
:return: math operation currently in use
|
||||
:rtype: cutlass_cppgen.MathOperation
|
||||
"""
|
||||
return self._math_operation
|
||||
|
||||
@math_operation.setter
|
||||
def math_operation(self, mo: cutlass_cppgen.MathOperation):
|
||||
if isinstance(mo, str):
|
||||
mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
|
||||
|
||||
if not self.specified_kernel_cc:
|
||||
if self.current_cc in [90, 100, 101, 103]:
|
||||
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
|
||||
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
||||
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
self._reset_options(80)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
elif self.current_cc in [90, 100, 101, 103]:
|
||||
raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
|
||||
"To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
|
||||
"parameter when constructing the plan.")
|
||||
|
||||
self._math_operation = mo
|
||||
self._reset_operations()
|
||||
|
||||
def _elements_per_access(self):
|
||||
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
|
||||
return 1
|
||||
elif self._element_c != DataType.void:
|
||||
return 128 // DataTypeSize[self._element_c]
|
||||
else:
|
||||
return 128 // max(self.possible_operations.alignments("C"))
|
||||
|
||||
def _create_epilogue_functor_activation(self, activation):
|
||||
"""
|
||||
Returns the epilogue functor with given activation function
|
||||
"""
|
||||
if self.epilogue_functor is None:
|
||||
elements_per_access = self._elements_per_access()
|
||||
else:
|
||||
elements_per_access = self.epilogue_functor.epilogue_vector_length
|
||||
|
||||
if not self.specified_kernel_cc:
|
||||
if self.current_cc in [90, 100, 101, 103] and activation != identity:
|
||||
# CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
|
||||
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
||||
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
if self._element_c != self._element_d:
|
||||
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
|
||||
self._reset_options(80)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None):
|
||||
# SM80 fallback kernels are currently used. Since an identity activation is requested,
|
||||
# we can switch back to using SM90 kernels.
|
||||
self._reset_options(self.cc)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
else:
|
||||
if self.current_cc in [90, 100, 101, 103] and activation != identity:
|
||||
raise Exception("Epilogues with elementwise fusion are not currently supported "
|
||||
"in the Python interface for 3.x kernels. To use 2.x kernels "
|
||||
"with fused elementwise epilogues, do not set the `kernel_cc` "
|
||||
"parameter when constructing the plan.")
|
||||
|
||||
return get_activation_epilogue(
|
||||
activation,
|
||||
self._element_d,
|
||||
elements_per_access,
|
||||
self._element_accumulator,
|
||||
self._element_accumulator,
|
||||
)
|
||||
|
||||
def _reset_epilogue_functor_activation(self, activation):
|
||||
"""
|
||||
Set the epilogue functor based on the provided activation function
|
||||
"""
|
||||
self.epilogue_functor = self._create_epilogue_functor_activation(activation)
|
||||
|
||||
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
|
||||
"""
|
||||
Reset the alignment of the current epilogue functor based on alignment C
|
||||
"""
|
||||
if isinstance(epilogue_functor, EpilogueFunctorVisitor):
|
||||
return epilogue_functor
|
||||
|
||||
if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
|
||||
# Identity epilogue does not have 'activation_functor'
|
||||
activation = identity
|
||||
else:
|
||||
activation = epilogue_functor.activation_functor
|
||||
|
||||
epilogue_functor = get_activation_epilogue(
|
||||
activation,
|
||||
self._element_d,
|
||||
alignment,
|
||||
self._element_accumulator,
|
||||
self._element_accumulator,
|
||||
)
|
||||
return epilogue_functor
|
||||
|
||||
@property
|
||||
def activation(self):
|
||||
"""
|
||||
Returns the type of the current activation function used
|
||||
"""
|
||||
if hasattr(self.epilogue_functor, "activation_functor"):
|
||||
return self.epilogue_functor.activation_functor
|
||||
else:
|
||||
return identity
|
||||
|
||||
@activation.setter
|
||||
def activation(self, act):
|
||||
"""
|
||||
Sets the type of the activation function to use
|
||||
Activation can come with a set of arguments
|
||||
|
||||
:param act: type of activation function to use
|
||||
:type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
|
||||
|
||||
"""
|
||||
if isinstance(act, tuple):
|
||||
if isinstance(act[0], str):
|
||||
act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
|
||||
else:
|
||||
act_fn = act[0]
|
||||
self._reset_epilogue_functor_activation(act_fn)
|
||||
self._activation_args = act[1]
|
||||
self._activation = act[0]
|
||||
else:
|
||||
if isinstance(act, str):
|
||||
act = getattr(cutlass_cppgen.backend.epilogue, act)
|
||||
self._reset_epilogue_functor_activation(act)
|
||||
self._activation = act
|
||||
|
||||
@property
|
||||
def epilogue_visitor(self):
|
||||
"""
|
||||
Return the epilogue functor
|
||||
"""
|
||||
return self.epilogue_functor
|
||||
|
||||
@epilogue_visitor.setter
|
||||
def epilogue_visitor(self, visitor):
|
||||
"""
|
||||
Create the epilogue visitor
|
||||
"""
|
||||
self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)
|
||||
|
||||
# The epilogue_functor may consume too much shared memory
|
||||
# Reset the possible operations
|
||||
if self.cc not in [90, 100, 101, 103]:
|
||||
# The shared memory is only a concern for sm90+ epilogue
|
||||
# In sm80, the epilogue and mainloop share the shared memory
|
||||
return
|
||||
|
||||
datatype_comb = self.possible_operations.datatype_comb
|
||||
layout_comb = self.possible_operations.layout_comb
|
||||
new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
|
||||
for operation in self.possible_operations.all_operations:
|
||||
td = datatypes.td_from_profiler_op(operation)
|
||||
# Filter invalid epilogue schedules
|
||||
if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
|
||||
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
|
||||
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
|
||||
continue
|
||||
epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
|
||||
|
||||
# Verify the maximum number of mainloop stages
|
||||
mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
|
||||
smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
|
||||
mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
|
||||
if mainloop_stages < 2:
|
||||
# Mainloop stages must >= 2
|
||||
continue
|
||||
|
||||
new_possible_operations.add(operation)
|
||||
if len(new_possible_operations.all_operations) == 0:
|
||||
raise RuntimeError(
|
||||
"The epilogue consumes too much shared memory. "
|
||||
"No valid tile description is found in the generator.")
|
||||
self.possible_operations = new_possible_operations
|
||||
|
||||
|
||||
def run_setup(self):
|
||||
"""
|
||||
Steps that must be taken before caling `plan.run()`
|
||||
"""
|
||||
# Initialize the memory pool if, if not already done
|
||||
cutlass_cppgen.get_memory_pool()
|
||||
Reference in New Issue
Block a user