CUTLASS 3.4.0 (#1286)

* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
Pradeep Ramani
2023-12-29 12:21:31 -08:00
committed by GitHub
parent b7508e3379
commit 8236f30675
211 changed files with 11409 additions and 2763 deletions

View File

@ -14,7 +14,7 @@ import cutlass
import numpy as np
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
A, B, C, D = [np.ones((4096, 4096), dtype=np.float16) for i in range(4)]
A, B, C, D = [np.ones((1024, 1024), dtype=np.float16) for i in range(4)]
plan.run(A, B, C, D)
```
@ -67,7 +67,7 @@ The CUTLASS Python interface currently supports the following operations:
We recommend using the CUTLASS Python interface via an [NGC PyTorch Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch):
```bash
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3 -p 8888:8888
```
The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8 and 3.9.
@ -99,6 +99,24 @@ If you would like to be able to make changes to CUTLASS Python interface and hav
pip install -e .
```
To test that your installation was successful, you can run:
```python
import cutlass
import numpy as np
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
A, B, C, D = [np.ones((128, 128), dtype=np.float16) for i in range(4)]
plan.run(A, B, C, D)
```
### Deep learning framework CUDA extensions
The CUTLASS Python interface provides utilities for exporting a CUTLASS kernel to a deep learning framework CUDA extensions. Currently, PyTorch CUDA extensions can be exported, but a similar pattern could be applied for other frameworks as well. An example of this is provided [here](/examples/python/02_pytorch_extension_grouped_gemm.ipynb).
Currently, the following operations can be exported to a PyTorch CUDA extension:
* GEMM
* Grouped GEMM
* Conv2d
### Examples
Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python).

View File

@ -75,6 +75,7 @@ from cutlass_library import (
DataType,
EpilogueScheduleType,
KernelScheduleType,
MathOperation,
LayoutType,
OpcodeClass,
TileDescription,
@ -120,7 +121,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '3.3.0'
this.__version__ = '3.4.0'
from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch

View File

@ -34,7 +34,8 @@ import ctypes
from cutlass_library import (
DataType,
KernelScheduleType
KernelScheduleType,
TileSchedulerType
)
from cutlass.backend.library import DataTypeSizeBytes
@ -99,6 +100,7 @@ class StrideBatched_(ctypes.Structure):
]
class GenericMainloopArguments3x_(ctypes.Structure):
"""
Structure representing the superset of possible mainloop arguments.
@ -115,6 +117,45 @@ class GenericMainloopArguments3x_(ctypes.Structure):
]
class _PersistentTileSchedulerArguments(ctypes.Structure):
_fields_ = [
("max_swizzle_size", ctypes.c_int),
("raster_order_option", ctypes.c_int),
]
class _PersistentTileSchedulerStreamKArguments(ctypes.Structure):
_fields_ = [
("splits", ctypes.c_int),
("max_swizzle_size", ctypes.c_int),
("raster_order_option", ctypes.c_int),
("reduction_mode", ctypes.c_int),
("decomposition_mode", ctypes.c_int),
]
def get_tile_scheduler_arguments_3x(
tile_scheduler: TileSchedulerType,
splits: int = 1):
max_swizzle_size = 1
raster_order_option = 0 # Heuristic
if tile_scheduler == TileSchedulerType.Persistent:
return _PersistentTileSchedulerArguments(
max_swizzle_size,
raster_order_option,
)
elif tile_scheduler == TileSchedulerType.StreamK:
reduction_mode = 0 # Deterministic
decomposition_mode = 0 # Heuristic
return _PersistentTileSchedulerStreamKArguments(
splits,
max_swizzle_size,
raster_order_option,
reduction_mode,
decomposition_mode,
)
def get_mainloop_arguments_3x(
kernel_schedule: KernelScheduleType,
element_A,
@ -172,7 +213,7 @@ def get_mainloop_arguments_3x(
return _MainloopArgumentsTma
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
if hasattr(epilogue_functor, "visitor"):
class _EpilogueArguments(ctypes.Structure):
@ -187,7 +228,6 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
self.arg_C = epilogue_functor.arg_c_type(ptr_c)
self.arg_D = epilogue_functor.arg_d_type(ptr_d)
else:
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _EpilogueOutputOpParams),
@ -210,7 +250,7 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
("mainloop", mainloop_arguments),
("epilogue", _EpilogueArguments),
("hw_info", _HardwareInfo),
("splits", ctypes.c_int)
("scheduler", type(scheduler_args)),
]
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo

View File

@ -87,7 +87,7 @@ class Sm90LoadSrcImpl(LoadSrcImpl):
self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch;
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
"""
return self._type_decl

View File

@ -44,6 +44,7 @@ from cutlass.backend.epilogue import EpilogueFunctorBase
import cutlass.backend.evt.backend
from cutlass.backend.frontend import TensorFrontend
from cutlass.utils.datatypes import is_numpy_tensor
from cutlass.backend.evt.passes.util import cc_map
class EpilogueFunctorVisitor(EpilogueFunctorBase):
@ -56,7 +57,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase):
"""
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
# Type of Emitter based on CC
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc}Emitter")
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
# Visitor Types
self.visitor = visitor

View File

@ -45,6 +45,7 @@ from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass.backend.evt.passes.util import cc_map
class PassGetImpl(EVTPassBase):
@ -82,8 +83,8 @@ class PassGetImpl(EVTPassBase):
self.no_op_elimination()
# Lower to cc-specific impl
for node_meta in self.dag_ir.nodes_meta:
node_impl_ccs = getattr(evt_backend, f"sm{self.cc}_nodes")
node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
node_meta.underlying_impl = getattr(
node_impl_ccs,
f"Sm{self.cc}" + node_meta.underlying_impl.__class__.__name__
f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__
)(node_meta)

View File

@ -39,6 +39,7 @@ from typing import Any
import networkx as nx
from cutlass.backend.evt.ir import DAGIR
from cutlass.backend.evt.passes.util import cc_map
class EVTPassBase:
@ -102,7 +103,7 @@ class EVTPassBase:
// sm80 specific method
return
"""
func_name = f"sm{self.cc}_{func.__name__}"
func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
if hasattr(self, func_name):
return getattr(self, func_name)
else:

View File

@ -0,0 +1,43 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for passes
"""
# Map from the CC of the kernel to the EVT implementation that the CC targets
cc_map = {
80: 80,
86: 80,
89: 80,
90: 90,
}

View File

@ -82,7 +82,8 @@ from cutlass.backend.c_types import (
get_gemm_arguments_3x,
get_gemm_arguments_streamk,
get_gemm_grouped_arguments,
get_mainloop_arguments_3x
get_mainloop_arguments_3x,
get_tile_scheduler_arguments_3x,
)
from cutlass.backend.library import (
ApiVersion,
@ -554,6 +555,7 @@ class GemmArguments3x(GemmArguments2x):
mainloop,
epilogue,
hw_info,
self.operation.rt_module.scheduler_args
)
return self.arguments
@ -1163,7 +1165,9 @@ extern "C" {
operation.A.alignment,
operation.B.alignment
)
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor)
self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler)
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(
self.mainloop_args, operation.epilogue_functor, self.scheduler_args)
def get_device_workspace_size(self, arguments: GemmArguments3x):
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))

View File

@ -34,6 +34,7 @@
Classes containing valid operations for a given compute capability and data types.
"""
from itertools import combinations_with_replacement
import logging
from cuda import __version__
@ -60,6 +61,7 @@ class KernelsForDataType:
def __init__(self, datatype_comb: tuple, layout_comb: tuple):
self.datatype_comb = datatype_comb
self.layout_comb = layout_comb
self.math_operations = set()
# Dictionary mapping from alignment (int) to a list of kernels that fit the alignment
# constraint for the data type combination
@ -73,6 +75,7 @@ class KernelsForDataType:
if alignment_key not in self.kernels_by_alignment:
self.kernels_by_alignment[alignment_key] = []
self.kernels_by_alignment[alignment_key].append(operation)
self.math_operations.add(operation.tile_description.math_instruction.math_operation)
def alignments(self, operand: str):
"""
@ -100,11 +103,14 @@ class KernelsForDataType:
ops.extend(alignment_ops)
return ops
def default_operation(self):
def default_operation(self, math_operation: cutlass.MathOperation):
key = sorted(list(self.kernels_by_alignment.keys()))[0]
return self.kernels_by_alignment[key][0]
kernels = self.kernels_by_alignment[key]
if math_operation is not None:
kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
return kernels[0]
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int):
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass.MathOperation):
"""
Returns operations satisfying the alignment constraints
@ -114,6 +120,8 @@ class KernelsForDataType:
:type alignment_B: int
:param alignment_C: alignment constraint of operations to return
:type alignment_C: int
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:return: list of operations
:rtype: list
@ -126,13 +134,26 @@ class KernelsForDataType:
min_alignment = min(alignment_A, alignment_B, alignment_C)
key = f"{min_alignment} {min_alignment} {min_alignment}"
if key not in self.kernels_by_alignment:
raise Exception(
f"No operations of alignment {og_key} found for data type and layout "
f"combination {self.datatype_comb} {self.layout_comb}. Tried to fall back "
f"to alignment {key}, but that was also not compatible. Compatible alignments "
f"are {self.kernels_by_alignment.keys()}"
)
return self.kernels_by_alignment[key]
# Finally, go through all available alignment combinations and find
# one for which all values are less than those passed in.
key = None
alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
for align_A, align_B, align_C in alignments:
if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C:
key = f"{align_A} {align_B} {align_C}"
break
if key is None:
raise Exception(
f"No operations of alignment {og_key} found for data type and layout "
f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
f"are {self.kernels_by_alignment.keys()}"
)
ops = self.kernels_by_alignment[key]
if math_operation is not None:
ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
return ops
def _operand_idx(self, key: str) -> int:
operand_list = ["A", "B", "C"]
@ -187,6 +208,18 @@ class KernelsForDataType:
for alignment in self.kernels_by_alignment.keys():
self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
def supports_math_operation(self, math_operation: cutlass.MathOperation) -> bool:
"""
Returns whether `math_operation` is supported by at least one operation.
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:return: whether math_operation is supported by at least one operation
:rtype: bool
"""
return math_operation is None or math_operation in self.math_operations
class ArchOptions:
"""
@ -213,7 +246,8 @@ class ArchOptions:
allowed_math_operations: list = [
cutlass_library.MathOperation.multiply_add,
cutlass_library.MathOperation.multiply_add_saturate,
cutlass_library.MathOperation.multiply_add_mixed_input_upcast
cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
cutlass_library.MathOperation.multiply_add_fast_f32
]
):
self.cc = kernel_cc
@ -270,8 +304,6 @@ class ArchOptions:
if mi.math_operation not in self.allowed_math_operations:
continue
datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
# Prune operations that don't fit in shared memory
td = td_from_profiler_op(op)
if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
@ -323,6 +355,15 @@ class ArchOptions:
(cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
]
# Add FP8 A/B/C
fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
for type_comb in combinations_with_replacement(fp8_types, 3):
types.append(type_comb)
# Add FP8 A/B with FP32 C
for type_comb in combinations_with_replacement(fp8_types, 2):
types.append(type_comb + (cutlass.DataType.f32,))
layouts = [
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
@ -395,7 +436,7 @@ class ArchOptions:
self.operations_by_opclass[oc][comb].sort()
def opclass_supports_combination(
self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple
self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
) -> bool:
"""
Returns whether the provided operation class supports the provided data type and layout combination
@ -406,6 +447,8 @@ class ArchOptions:
:type datatype_comb: tuple[cutlass_library.DataType]
:param layout_comb: tuple of data types for (layout_A, layout_B)
:type layout_comb: tuple[cutlass_library.LayoutType]
:param math_operation: math operation to consider or None if any can be considered
:type math_operation: cutlass.MathOperation
:return: set of operation classes that support the provided data type and layout combination
:rtype: set
@ -413,7 +456,14 @@ class ArchOptions:
if op_class not in self.operations_by_opclass:
raise Exception(f"Unexpected or unsupported operation class {op_class}")
return (datatype_comb, layout_comb) in self.operations_by_opclass[op_class]
if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
if math_operation is not None:
return operations.supports_math_operation(math_operation)
else:
return True
return False
def supporting_opclasses(
self,
@ -422,6 +472,7 @@ class ArchOptions:
element_accumulator: cutlass_library.DataType,
layout_a: cutlass_library.LayoutType,
layout_b: cutlass_library.LayoutType,
math_operation: cutlass_library.MathOperation,
) -> set:
"""
Returns a set of operation classes that support the provided data type combination
@ -436,6 +487,8 @@ class ArchOptions:
:type layout_a: cutlass_library.LayoutType
:param layout_b: layout of operand B
:type layout_b: cutlass_library.LayoutType
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:return: set of operation classes that support the provided data type combination
:rtype: set
@ -445,7 +498,7 @@ class ArchOptions:
layout_comb = (layout_a, layout_b)
for op_class in self.operations_by_opclass.keys():
if self.opclass_supports_combination(op_class, datatype_comb, layout_comb):
if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
supporting_op_classes.add(op_class)
return supporting_op_classes
@ -457,6 +510,7 @@ class ArchOptions:
element_accumulator: cutlass_library.DataType,
layout_a: cutlass_library.LayoutType,
layout_b: cutlass_library.LayoutType,
math_operation: cutlass_library.MathOperation,
) -> KernelsForDataType:
"""
Returns whether the provided operation class supports the provided data type combination
@ -473,13 +527,15 @@ class ArchOptions:
:type layout_a: cutlass_library.LayoutType
:param layout_b: layout of operand B
:type layout_b: cutlass_library.LayoutType
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:return: container of kernels by alignment supported by the provided combination of parameters
:rtype: KernelsForDataType
"""
datatype_comb = (element_a, element_b, element_accumulator)
layout_comb = (layout_a, layout_b)
if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb):
if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
raise Exception(
f"Data type layout combination {datatype_comb}, {layout_comb} "
f"is not supported by opcode class {op_class} on CC {self.cc}."

View File

@ -293,7 +293,7 @@ class Conv2d(OperationBase):
self.possible_op_classes = self.options.supporting_opclasses(
self._element_a, self._element_b, self._element_accumulator,
self._layout_a, self._layout_b
self._layout_a, self._layout_b, self._math_operation
)
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
@ -301,8 +301,13 @@ class Conv2d(OperationBase):
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.Simt
else:
if self._math_operation is not None:
math_op_str = f' and math operation {self._math_operation}'
else:
math_op_str = ''
raise Exception(f'No kernel configuration found for supported data type and layout '
f'combination {datatype_comb}x{layout_comb}')
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
if reset_epilogue:
self._reset_epilogue_functor_activation(epilogue.identity)
@ -345,7 +350,7 @@ class Conv2d(OperationBase):
return
if isinstance(td, dict):
if self._tile_description is None:
op = self.possible_operations.default_operation()
op = self.possible_operations.default_operation(self._math_operation)
self._tile_description = datatypes.td_from_profiler_op(op)
if "cluster_shape" in td.keys():
if td["cluster_shape"] != [1, 1, 1]:
@ -397,6 +402,11 @@ class Conv2d(OperationBase):
description_str = []
for op in self.possible_operations.all_operations:
td = datatypes.td_from_profiler_op(op)
if self._math_operation is not None:
if td.math_instruction.math_operation != self._math_operation:
continue
if str(td) not in description_str:
description_str.append(str(td))
descriptions.append(td)
@ -569,7 +579,7 @@ class Conv2d(OperationBase):
if self.tile_description is not None:
tile_description = self.tile_description
else:
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0]
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
tile_description = datatypes.td_from_profiler_op(op)
else:
valid, err_str = self._valid_tile_description(tile_description)

View File

@ -202,14 +202,14 @@ class Gemm(OperationBase):
:type element_C: cutlass.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:type layout_A: layout of operand A
:param layout_A: cutlass.LayoutType
:type layout_B: layout of operand B
:param layout_B: cutlass.LayoutType
:type layout_C: layout of operand C
:param layout_C: cutlass.LayoutType
:type layout_D: layout of operand D
:param layout_D: cutlass.LayoutType
:param layout_A: layout of operand A
:type layout_A: cutlass.LayoutType
:param layout_B: layout of operand B
:type layout_B: cutlass.LayoutType
:param layout_C: layout of operand C
:type layout_C: cutlass.LayoutType
:param layout_D: layout of operand D
:type layout_D: cutlass.LayoutType
"""
def __init__(
@ -281,17 +281,23 @@ class Gemm(OperationBase):
# Set the default op class
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
layout_comb = (self._layout_a, self._layout_b)
self.possible_op_classes = self.options.supporting_opclasses(
self._element_a, self._element_b, self._element_accumulator,
self._layout_a, self._layout_b)
self._layout_a, self._layout_b, self._math_operation)
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.TensorOp
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.Simt
else:
if self._math_operation is not None:
math_op_str = f' and math operation {self._math_operation}'
else:
math_op_str = ''
raise Exception(f'No kernel configuration found for supported data type and layout '
f'combination {datatype_comb}x{layout_comb}')
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
if reset_epilogue:
self._reset_epilogue_functor_activation(cutlass.epilogue.identity)
@ -349,7 +355,7 @@ class Gemm(OperationBase):
return
if isinstance(td, dict):
if self._tile_description is None:
op = self.possible_operations.default_operation()
op = self.possible_operations.default_operation(self._math_operation)
self._tile_description = datatypes.td_from_profiler_op(op)
td = self._tile_description.clone_and_update(td)
@ -394,7 +400,10 @@ class Gemm(OperationBase):
:returns: list of valid tile descriptions for the operations
:rtype: list
"""
return [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
if self._math_operation is not None:
tds = [td for td in tds if td.tile_description.math_instruction == self._math_operation]
return tds
def construct(
self, tile_description: TileDescription = None,
@ -423,18 +432,19 @@ class Gemm(OperationBase):
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
alignment_pref_C = max(self.possible_operations.alignments("C"))
if self._element_c != DataType.void:
alignment_pref_C = min(128 // DataTypeSize[self._element_c], alignment_pref_C)
alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C)
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
if alignment_C is None:
alignment_C = max(self.possible_operations.alignments("C"))
if self._element_c != DataType.void:
alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
if tile_description is None:
if self._tile_description is None:
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0]
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
tile_description = datatypes.td_from_profiler_op(op)
# The selected op may have lower alignment than that determined above, so we must
# reset alignment here.
alignment_C = op.C.alignment
else:
tile_description = self._tile_description
else:
@ -443,6 +453,9 @@ class Gemm(OperationBase):
raise Exception(f"Invalid tile description. {err_str}")
self._tile_description = tile_description
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
operation = GemmOperationUniversal(
arch=self.current_cc,
tile_description=tile_description,
@ -599,9 +612,13 @@ class Gemm(OperationBase):
"""
dtype, layout = datatypes.get_datatype_and_layout(tensor)
if dtype != ref_type or layout != ref_layout:
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
f'does not match the expected type and '
f'layout of ({ref_type}, {ref_layout}).')
try:
# Attempt to transpose the tensor to fit the desired layout
tensor = tensor.transpose(-1, -2)
except:
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
f'does not match the expected type and '
f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
def run(self, A=None, B=None, C=None, D=None,
alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments:

View File

@ -174,7 +174,7 @@ class GroupedGemm(Gemm):
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
if tile_description is None:
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0]
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
tile_description = datatypes.td_from_profiler_op(op)
else:
valid, err_str = self._valid_tile_description(tile_description)

View File

@ -36,7 +36,13 @@ Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv
from bisect import bisect_left
from cutlass_library import DataType, DataTypeSize, OperationKind, SharedMemPerCC
from cutlass_library import (
DataType,
DataTypeSize,
MathOperation,
OperationKind,
SharedMemPerCC
)
import cutlass
from cutlass import get_option_registry
@ -67,6 +73,7 @@ class OperationBase:
self.specified_kernel_cc = kernel_cc is not None
self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
self.tile_description = None
self._math_operation = None
self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)
@ -197,14 +204,10 @@ class OperationBase:
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
return tensor
#
# Opcode Related
#
@property
def opclass(self) -> cutlass.OpcodeClass:
"""
Returns the opcode class currently in use by the GEMM
Returns the opcode class currently in use
:return: opcode class currently in use
:rtype: cutlass.OpcodeClass
@ -226,15 +229,41 @@ class OperationBase:
# Changing the op class also changes the possible operations available. Reset these.
self.possible_operations = self.options.operations(
self.op_class, self._element_a, self._element_b,
self._element_accumulator, self._layout_a, self._layout_b)
self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
# Changing the op class changes the elements per access in the epilogue. Reset this.
if self.epilogue_functor is not None:
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
#
# Epilogue
#
@property
def math_operation(self) -> cutlass.MathOperation:
"""
Returns the math operation currently in use
:return: math operation currently in use
:rtype: cutlass.MathOperation
"""
return self._math_operation
@math_operation.setter
def math_operation(self, mo: cutlass.MathOperation):
if isinstance(mo, str):
mo = datatypes.getattr_enum(cutlass.MathOperation, mo)
if not self.specified_kernel_cc:
if self.current_cc == 90:
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif self.current_cc == 90:
raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
"To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
"parameter when constructing the plan.")
self._math_operation = mo
self._reset_operations()
def _elements_per_access(self):
if self.op_class == cutlass.OpcodeClass.Simt:
@ -262,7 +291,7 @@ class OperationBase:
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif (self.cc == 90 and self.current_cc != 90 and activation == identity):
elif (self.cc == 90 and self.current_cc != 90 and activation == identity and self._math_operation is None):
# SM80 fallback kernels are currently used. Since an identity activation is requested,
# we can switch back to using SM90 kernels.
self._reset_options(90)
@ -272,7 +301,7 @@ class OperationBase:
raise Exception("Epilogues with elementwise fusion are not currently supported "
"in the Python interface for 3.x kernels. To use 2.x kernels "
"with fused elementwise epilogues, do not set the `kernel_cc` "
"parameter when constructing the Gemm object.")
"parameter when constructing the plan.")
return get_activation_epilogue(
activation,
@ -364,6 +393,7 @@ class OperationBase:
# The shared memory is only a concern for sm90 epilogue
# In sm80, the epilogue and mainloop share the shared memory
return
datatype_comb = self.possible_operations.datatype_comb
layout_comb = self.possible_operations.layout_comb
new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)

View File

@ -176,6 +176,17 @@ def is_torch_available():
cutlass.DataType.s32: torch.int32,
cutlass.DataType.u8: torch.uint8,
}
def possibly_add_type(torch_type_name, cutlass_type):
# Only try adding the type if the version of torch being used supports it
if hasattr(torch, torch_type_name):
torch_type = getattr(torch, torch_type_name)
_torch_to_library_dict[torch_type] = cutlass_type
_library_to_torch_dict[cutlass_type] = torch_type
possibly_add_type("float8_e4m3fn", cutlass.DataType.e4m3)
possibly_add_type("float8_e5m2", cutlass.DataType.e5m2)
except ImportError:
torch_available = False
_torch_to_library_dict = {}

View File

@ -61,7 +61,8 @@ class GemmOperation:
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default, extra_args = None):
tile_scheduler = TileSchedulerType.Default
):
kinds_3x = {
GemmKind.Universal3x,
@ -88,6 +89,10 @@ class GemmOperation:
self.epilogue_schedule = epilogue_schedule
self.element_epilogue = element_epilogue
self.epilogue_functor = epilogue_functor
if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination:
self.epilogue_functor = EpilogueFunctor3x.LinearCombination
self.swizzling_functor = swizzling_functor
self.tile_scheduler = tile_scheduler
@ -709,9 +714,9 @@ class EmitGemmUniversal3xInstance:
]
self.builtin_epilogue_functor_template = """
${epilogue_functor}<
${element_d},
${element_epilogue},
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>
"""
@ -726,7 +731,8 @@ using ${operation_name}_epilogue =
${element_accumulator}, ${element_epilogue},
${element_c}, ${layout_c}, ${align_c},
${element_d}, ${layout_d}, ${align_d},
${epilogue_schedule}
${epilogue_schedule},
${epilogue_functor}
>::CollectiveOp;
using ${operation_name}_mainloop =
@ -757,9 +763,11 @@ struct ${operation_name} :
def instance_template(self):
return """
${compile_guard_start}
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
manifest.append(
new ${gemm_kind}<GemmKernel>("${operation_name}"));
{
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
manifest.append(
new ${gemm_kind}<GemmKernel>("${operation_name}"));
}
${compile_guard_end}
"""
@ -788,9 +796,8 @@ ${compile_guard_end}
# Support built-in epilogue functors or user-defined functions
if isinstance(operation.epilogue_functor, enum.Enum):
values = {
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
}
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
else:
@ -799,6 +806,9 @@ ${compile_guard_end}
element_a = DataTypeTag[operation.A.element]
element_b = DataTypeTag[operation.B.element]
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
element_a = DataTypeTag[operation.A.element]
element_b = DataTypeTag[operation.B.element]
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
values = {
'operation_name': operation.procedural_name(),
'operation_suffix': self.operation_suffix,

View File

@ -192,14 +192,14 @@ def CreateGemmUniversal3xOperator(
C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1])
D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1])
extra_args = {}
gemm_op_extra_args = {}
gemm_kind = GemmKind.Universal3x
element_compute = data_type.get("epi_type", data_type["acc_type"])
operation = GemmOperation(
gemm_kind, tile_description.minimum_compute_capability,
tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D,
kernel_schedule, epilogue_schedule, tile_scheduler, extra_args)
kernel_schedule, epilogue_schedule, tile_scheduler, **gemm_op_extra_args)
manifest.append(operation)
operations.append(operation)

View File

@ -466,6 +466,13 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
}
class EpilogueFunctor3x(enum.Enum):
LinearCombination = enum_auto()
#
EpilogueFunctor3xTag = {
EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
}
class TileSchedulerType(enum.Enum):
Default = enum_auto()
Persistent = enum_auto()

View File

@ -429,7 +429,7 @@ class Manifest:
self.kernel_filter_list = []
else:
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
_LOGGER.info("Using {filter_count} kernel filters from {filter_file}".format(
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
filter_count = len(self.kernel_filter_list),
filter_file = args.kernel_filter_file))

View File

@ -101,7 +101,7 @@ class Layout(LayoutBase):
# cosize(layout) Size of the codomain
def cosize(self):
return tuple_max(tuple((1, elem_scale(self.shape, self.stride))))
return self(self.size() - 1) + 1
# print and str
def __str__(self):

View File

@ -51,7 +51,7 @@ setup_pycute.perform_setup()
setup(
name='cutlass',
version='3.3.0',
version='3.4.0',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='3.3.0',
version='3.4.0',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='3.3.0',
version='3.4.0',
description='Python implementation of CuTe',
packages=['pycute'],
)