CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@ -14,7 +14,7 @@ import cutlass
|
||||
import numpy as np
|
||||
|
||||
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
|
||||
A, B, C, D = [np.ones((4096, 4096), dtype=np.float16) for i in range(4)]
|
||||
A, B, C, D = [np.ones((1024, 1024), dtype=np.float16) for i in range(4)]
|
||||
plan.run(A, B, C, D)
|
||||
```
|
||||
|
||||
@ -67,7 +67,7 @@ The CUTLASS Python interface currently supports the following operations:
|
||||
We recommend using the CUTLASS Python interface via an [NGC PyTorch Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch):
|
||||
|
||||
```bash
|
||||
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3
|
||||
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3 -p 8888:8888
|
||||
```
|
||||
|
||||
The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8 and 3.9.
|
||||
@ -99,6 +99,24 @@ If you would like to be able to make changes to CUTLASS Python interface and hav
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
To test that your installation was successful, you can run:
|
||||
```python
|
||||
import cutlass
|
||||
import numpy as np
|
||||
|
||||
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
|
||||
A, B, C, D = [np.ones((128, 128), dtype=np.float16) for i in range(4)]
|
||||
plan.run(A, B, C, D)
|
||||
```
|
||||
|
||||
### Deep learning framework CUDA extensions
|
||||
The CUTLASS Python interface provides utilities for exporting a CUTLASS kernel to a deep learning framework CUDA extensions. Currently, PyTorch CUDA extensions can be exported, but a similar pattern could be applied for other frameworks as well. An example of this is provided [here](/examples/python/02_pytorch_extension_grouped_gemm.ipynb).
|
||||
|
||||
Currently, the following operations can be exported to a PyTorch CUDA extension:
|
||||
* GEMM
|
||||
* Grouped GEMM
|
||||
* Conv2d
|
||||
|
||||
### Examples
|
||||
Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python).
|
||||
|
||||
|
||||
@ -75,6 +75,7 @@ from cutlass_library import (
|
||||
DataType,
|
||||
EpilogueScheduleType,
|
||||
KernelScheduleType,
|
||||
MathOperation,
|
||||
LayoutType,
|
||||
OpcodeClass,
|
||||
TileDescription,
|
||||
@ -120,7 +121,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '3.3.0'
|
||||
this.__version__ = '3.4.0'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
|
||||
@ -34,7 +34,8 @@ import ctypes
|
||||
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
KernelScheduleType
|
||||
KernelScheduleType,
|
||||
TileSchedulerType
|
||||
)
|
||||
from cutlass.backend.library import DataTypeSizeBytes
|
||||
|
||||
@ -99,6 +100,7 @@ class StrideBatched_(ctypes.Structure):
|
||||
]
|
||||
|
||||
|
||||
|
||||
class GenericMainloopArguments3x_(ctypes.Structure):
|
||||
"""
|
||||
Structure representing the superset of possible mainloop arguments.
|
||||
@ -115,6 +117,45 @@ class GenericMainloopArguments3x_(ctypes.Structure):
|
||||
]
|
||||
|
||||
|
||||
class _PersistentTileSchedulerArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("max_swizzle_size", ctypes.c_int),
|
||||
("raster_order_option", ctypes.c_int),
|
||||
]
|
||||
|
||||
|
||||
class _PersistentTileSchedulerStreamKArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("splits", ctypes.c_int),
|
||||
("max_swizzle_size", ctypes.c_int),
|
||||
("raster_order_option", ctypes.c_int),
|
||||
("reduction_mode", ctypes.c_int),
|
||||
("decomposition_mode", ctypes.c_int),
|
||||
]
|
||||
|
||||
|
||||
def get_tile_scheduler_arguments_3x(
|
||||
tile_scheduler: TileSchedulerType,
|
||||
splits: int = 1):
|
||||
max_swizzle_size = 1
|
||||
raster_order_option = 0 # Heuristic
|
||||
if tile_scheduler == TileSchedulerType.Persistent:
|
||||
return _PersistentTileSchedulerArguments(
|
||||
max_swizzle_size,
|
||||
raster_order_option,
|
||||
)
|
||||
elif tile_scheduler == TileSchedulerType.StreamK:
|
||||
reduction_mode = 0 # Deterministic
|
||||
decomposition_mode = 0 # Heuristic
|
||||
return _PersistentTileSchedulerStreamKArguments(
|
||||
splits,
|
||||
max_swizzle_size,
|
||||
raster_order_option,
|
||||
reduction_mode,
|
||||
decomposition_mode,
|
||||
)
|
||||
|
||||
|
||||
def get_mainloop_arguments_3x(
|
||||
kernel_schedule: KernelScheduleType,
|
||||
element_A,
|
||||
@ -172,7 +213,7 @@ def get_mainloop_arguments_3x(
|
||||
return _MainloopArgumentsTma
|
||||
|
||||
|
||||
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
|
||||
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args):
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
if hasattr(epilogue_functor, "visitor"):
|
||||
class _EpilogueArguments(ctypes.Structure):
|
||||
@ -187,7 +228,6 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
|
||||
self.arg_C = epilogue_functor.arg_c_type(ptr_c)
|
||||
self.arg_D = epilogue_functor.arg_d_type(ptr_d)
|
||||
else:
|
||||
|
||||
class _EpilogueArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("epilogue", _EpilogueOutputOpParams),
|
||||
@ -210,7 +250,7 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
|
||||
("mainloop", mainloop_arguments),
|
||||
("epilogue", _EpilogueArguments),
|
||||
("hw_info", _HardwareInfo),
|
||||
("splits", ctypes.c_int)
|
||||
("scheduler", type(scheduler_args)),
|
||||
]
|
||||
|
||||
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo
|
||||
|
||||
@ -87,7 +87,7 @@ class Sm90LoadSrcImpl(LoadSrcImpl):
|
||||
self._type_decl = f"""
|
||||
using ElementC = {DataTypeTag[self.element]};
|
||||
using StrideC = {self.stride_mnl};
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch;
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
@ -44,6 +44,7 @@ from cutlass.backend.epilogue import EpilogueFunctorBase
|
||||
import cutlass.backend.evt.backend
|
||||
from cutlass.backend.frontend import TensorFrontend
|
||||
from cutlass.utils.datatypes import is_numpy_tensor
|
||||
from cutlass.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class EpilogueFunctorVisitor(EpilogueFunctorBase):
|
||||
@ -56,7 +57,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase):
|
||||
"""
|
||||
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
|
||||
# Type of Emitter based on CC
|
||||
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc}Emitter")
|
||||
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
|
||||
|
||||
# Visitor Types
|
||||
self.visitor = visitor
|
||||
|
||||
@ -45,6 +45,7 @@ from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
|
||||
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class PassGetImpl(EVTPassBase):
|
||||
@ -82,8 +83,8 @@ class PassGetImpl(EVTPassBase):
|
||||
self.no_op_elimination()
|
||||
# Lower to cc-specific impl
|
||||
for node_meta in self.dag_ir.nodes_meta:
|
||||
node_impl_ccs = getattr(evt_backend, f"sm{self.cc}_nodes")
|
||||
node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
|
||||
node_meta.underlying_impl = getattr(
|
||||
node_impl_ccs,
|
||||
f"Sm{self.cc}" + node_meta.underlying_impl.__class__.__name__
|
||||
f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__
|
||||
)(node_meta)
|
||||
|
||||
@ -39,6 +39,7 @@ from typing import Any
|
||||
import networkx as nx
|
||||
|
||||
from cutlass.backend.evt.ir import DAGIR
|
||||
from cutlass.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class EVTPassBase:
|
||||
@ -102,7 +103,7 @@ class EVTPassBase:
|
||||
// sm80 specific method
|
||||
return
|
||||
"""
|
||||
func_name = f"sm{self.cc}_{func.__name__}"
|
||||
func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
|
||||
if hasattr(self, func_name):
|
||||
return getattr(self, func_name)
|
||||
else:
|
||||
|
||||
43
python/cutlass/backend/evt/passes/util.py
Normal file
43
python/cutlass/backend/evt/passes/util.py
Normal file
@ -0,0 +1,43 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for passes
|
||||
"""
|
||||
|
||||
# Map from the CC of the kernel to the EVT implementation that the CC targets
|
||||
cc_map = {
|
||||
80: 80,
|
||||
86: 80,
|
||||
89: 80,
|
||||
90: 90,
|
||||
}
|
||||
@ -82,7 +82,8 @@ from cutlass.backend.c_types import (
|
||||
get_gemm_arguments_3x,
|
||||
get_gemm_arguments_streamk,
|
||||
get_gemm_grouped_arguments,
|
||||
get_mainloop_arguments_3x
|
||||
get_mainloop_arguments_3x,
|
||||
get_tile_scheduler_arguments_3x,
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
ApiVersion,
|
||||
@ -554,6 +555,7 @@ class GemmArguments3x(GemmArguments2x):
|
||||
mainloop,
|
||||
epilogue,
|
||||
hw_info,
|
||||
self.operation.rt_module.scheduler_args
|
||||
)
|
||||
return self.arguments
|
||||
|
||||
@ -1163,7 +1165,9 @@ extern "C" {
|
||||
operation.A.alignment,
|
||||
operation.B.alignment
|
||||
)
|
||||
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor)
|
||||
self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler)
|
||||
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(
|
||||
self.mainloop_args, operation.epilogue_functor, self.scheduler_args)
|
||||
|
||||
def get_device_workspace_size(self, arguments: GemmArguments3x):
|
||||
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
Classes containing valid operations for a given compute capability and data types.
|
||||
"""
|
||||
|
||||
from itertools import combinations_with_replacement
|
||||
import logging
|
||||
|
||||
from cuda import __version__
|
||||
@ -60,6 +61,7 @@ class KernelsForDataType:
|
||||
def __init__(self, datatype_comb: tuple, layout_comb: tuple):
|
||||
self.datatype_comb = datatype_comb
|
||||
self.layout_comb = layout_comb
|
||||
self.math_operations = set()
|
||||
|
||||
# Dictionary mapping from alignment (int) to a list of kernels that fit the alignment
|
||||
# constraint for the data type combination
|
||||
@ -73,6 +75,7 @@ class KernelsForDataType:
|
||||
if alignment_key not in self.kernels_by_alignment:
|
||||
self.kernels_by_alignment[alignment_key] = []
|
||||
self.kernels_by_alignment[alignment_key].append(operation)
|
||||
self.math_operations.add(operation.tile_description.math_instruction.math_operation)
|
||||
|
||||
def alignments(self, operand: str):
|
||||
"""
|
||||
@ -100,11 +103,14 @@ class KernelsForDataType:
|
||||
ops.extend(alignment_ops)
|
||||
return ops
|
||||
|
||||
def default_operation(self):
|
||||
def default_operation(self, math_operation: cutlass.MathOperation):
|
||||
key = sorted(list(self.kernels_by_alignment.keys()))[0]
|
||||
return self.kernels_by_alignment[key][0]
|
||||
kernels = self.kernels_by_alignment[key]
|
||||
if math_operation is not None:
|
||||
kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
|
||||
return kernels[0]
|
||||
|
||||
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int):
|
||||
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass.MathOperation):
|
||||
"""
|
||||
Returns operations satisfying the alignment constraints
|
||||
|
||||
@ -114,6 +120,8 @@ class KernelsForDataType:
|
||||
:type alignment_B: int
|
||||
:param alignment_C: alignment constraint of operations to return
|
||||
:type alignment_C: int
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
|
||||
:return: list of operations
|
||||
:rtype: list
|
||||
@ -126,13 +134,26 @@ class KernelsForDataType:
|
||||
min_alignment = min(alignment_A, alignment_B, alignment_C)
|
||||
key = f"{min_alignment} {min_alignment} {min_alignment}"
|
||||
if key not in self.kernels_by_alignment:
|
||||
raise Exception(
|
||||
f"No operations of alignment {og_key} found for data type and layout "
|
||||
f"combination {self.datatype_comb} {self.layout_comb}. Tried to fall back "
|
||||
f"to alignment {key}, but that was also not compatible. Compatible alignments "
|
||||
f"are {self.kernels_by_alignment.keys()}"
|
||||
)
|
||||
return self.kernels_by_alignment[key]
|
||||
# Finally, go through all available alignment combinations and find
|
||||
# one for which all values are less than those passed in.
|
||||
key = None
|
||||
alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
|
||||
for align_A, align_B, align_C in alignments:
|
||||
if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C:
|
||||
key = f"{align_A} {align_B} {align_C}"
|
||||
break
|
||||
|
||||
if key is None:
|
||||
raise Exception(
|
||||
f"No operations of alignment {og_key} found for data type and layout "
|
||||
f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
|
||||
f"are {self.kernels_by_alignment.keys()}"
|
||||
)
|
||||
|
||||
ops = self.kernels_by_alignment[key]
|
||||
if math_operation is not None:
|
||||
ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
|
||||
return ops
|
||||
|
||||
def _operand_idx(self, key: str) -> int:
|
||||
operand_list = ["A", "B", "C"]
|
||||
@ -187,6 +208,18 @@ class KernelsForDataType:
|
||||
for alignment in self.kernels_by_alignment.keys():
|
||||
self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
|
||||
|
||||
def supports_math_operation(self, math_operation: cutlass.MathOperation) -> bool:
|
||||
"""
|
||||
Returns whether `math_operation` is supported by at least one operation.
|
||||
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
|
||||
:return: whether math_operation is supported by at least one operation
|
||||
:rtype: bool
|
||||
"""
|
||||
return math_operation is None or math_operation in self.math_operations
|
||||
|
||||
|
||||
class ArchOptions:
|
||||
"""
|
||||
@ -213,7 +246,8 @@ class ArchOptions:
|
||||
allowed_math_operations: list = [
|
||||
cutlass_library.MathOperation.multiply_add,
|
||||
cutlass_library.MathOperation.multiply_add_saturate,
|
||||
cutlass_library.MathOperation.multiply_add_mixed_input_upcast
|
||||
cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
|
||||
cutlass_library.MathOperation.multiply_add_fast_f32
|
||||
]
|
||||
):
|
||||
self.cc = kernel_cc
|
||||
@ -270,8 +304,6 @@ class ArchOptions:
|
||||
if mi.math_operation not in self.allowed_math_operations:
|
||||
continue
|
||||
|
||||
datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
|
||||
|
||||
# Prune operations that don't fit in shared memory
|
||||
td = td_from_profiler_op(op)
|
||||
if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
|
||||
@ -323,6 +355,15 @@ class ArchOptions:
|
||||
(cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
|
||||
]
|
||||
|
||||
# Add FP8 A/B/C
|
||||
fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
|
||||
for type_comb in combinations_with_replacement(fp8_types, 3):
|
||||
types.append(type_comb)
|
||||
|
||||
# Add FP8 A/B with FP32 C
|
||||
for type_comb in combinations_with_replacement(fp8_types, 2):
|
||||
types.append(type_comb + (cutlass.DataType.f32,))
|
||||
|
||||
layouts = [
|
||||
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
|
||||
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
|
||||
@ -395,7 +436,7 @@ class ArchOptions:
|
||||
self.operations_by_opclass[oc][comb].sort()
|
||||
|
||||
def opclass_supports_combination(
|
||||
self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple
|
||||
self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
|
||||
) -> bool:
|
||||
"""
|
||||
Returns whether the provided operation class supports the provided data type and layout combination
|
||||
@ -406,6 +447,8 @@ class ArchOptions:
|
||||
:type datatype_comb: tuple[cutlass_library.DataType]
|
||||
:param layout_comb: tuple of data types for (layout_A, layout_B)
|
||||
:type layout_comb: tuple[cutlass_library.LayoutType]
|
||||
:param math_operation: math operation to consider or None if any can be considered
|
||||
:type math_operation: cutlass.MathOperation
|
||||
|
||||
:return: set of operation classes that support the provided data type and layout combination
|
||||
:rtype: set
|
||||
@ -413,7 +456,14 @@ class ArchOptions:
|
||||
if op_class not in self.operations_by_opclass:
|
||||
raise Exception(f"Unexpected or unsupported operation class {op_class}")
|
||||
|
||||
return (datatype_comb, layout_comb) in self.operations_by_opclass[op_class]
|
||||
if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
|
||||
if math_operation is not None:
|
||||
return operations.supports_math_operation(math_operation)
|
||||
else:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def supporting_opclasses(
|
||||
self,
|
||||
@ -422,6 +472,7 @@ class ArchOptions:
|
||||
element_accumulator: cutlass_library.DataType,
|
||||
layout_a: cutlass_library.LayoutType,
|
||||
layout_b: cutlass_library.LayoutType,
|
||||
math_operation: cutlass_library.MathOperation,
|
||||
) -> set:
|
||||
"""
|
||||
Returns a set of operation classes that support the provided data type combination
|
||||
@ -436,6 +487,8 @@ class ArchOptions:
|
||||
:type layout_a: cutlass_library.LayoutType
|
||||
:param layout_b: layout of operand B
|
||||
:type layout_b: cutlass_library.LayoutType
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
|
||||
:return: set of operation classes that support the provided data type combination
|
||||
:rtype: set
|
||||
@ -445,7 +498,7 @@ class ArchOptions:
|
||||
layout_comb = (layout_a, layout_b)
|
||||
|
||||
for op_class in self.operations_by_opclass.keys():
|
||||
if self.opclass_supports_combination(op_class, datatype_comb, layout_comb):
|
||||
if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
|
||||
supporting_op_classes.add(op_class)
|
||||
return supporting_op_classes
|
||||
|
||||
@ -457,6 +510,7 @@ class ArchOptions:
|
||||
element_accumulator: cutlass_library.DataType,
|
||||
layout_a: cutlass_library.LayoutType,
|
||||
layout_b: cutlass_library.LayoutType,
|
||||
math_operation: cutlass_library.MathOperation,
|
||||
) -> KernelsForDataType:
|
||||
"""
|
||||
Returns whether the provided operation class supports the provided data type combination
|
||||
@ -473,13 +527,15 @@ class ArchOptions:
|
||||
:type layout_a: cutlass_library.LayoutType
|
||||
:param layout_b: layout of operand B
|
||||
:type layout_b: cutlass_library.LayoutType
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
|
||||
:return: container of kernels by alignment supported by the provided combination of parameters
|
||||
:rtype: KernelsForDataType
|
||||
"""
|
||||
datatype_comb = (element_a, element_b, element_accumulator)
|
||||
layout_comb = (layout_a, layout_b)
|
||||
if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb):
|
||||
if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
|
||||
raise Exception(
|
||||
f"Data type layout combination {datatype_comb}, {layout_comb} "
|
||||
f"is not supported by opcode class {op_class} on CC {self.cc}."
|
||||
|
||||
@ -293,7 +293,7 @@ class Conv2d(OperationBase):
|
||||
|
||||
self.possible_op_classes = self.options.supporting_opclasses(
|
||||
self._element_a, self._element_b, self._element_accumulator,
|
||||
self._layout_a, self._layout_b
|
||||
self._layout_a, self._layout_b, self._math_operation
|
||||
)
|
||||
|
||||
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
|
||||
@ -301,8 +301,13 @@ class Conv2d(OperationBase):
|
||||
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
|
||||
self.opclass = cutlass.OpcodeClass.Simt
|
||||
else:
|
||||
if self._math_operation is not None:
|
||||
math_op_str = f' and math operation {self._math_operation}'
|
||||
else:
|
||||
math_op_str = ''
|
||||
|
||||
raise Exception(f'No kernel configuration found for supported data type and layout '
|
||||
f'combination {datatype_comb}x{layout_comb}')
|
||||
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
|
||||
|
||||
if reset_epilogue:
|
||||
self._reset_epilogue_functor_activation(epilogue.identity)
|
||||
@ -345,7 +350,7 @@ class Conv2d(OperationBase):
|
||||
return
|
||||
if isinstance(td, dict):
|
||||
if self._tile_description is None:
|
||||
op = self.possible_operations.default_operation()
|
||||
op = self.possible_operations.default_operation(self._math_operation)
|
||||
self._tile_description = datatypes.td_from_profiler_op(op)
|
||||
if "cluster_shape" in td.keys():
|
||||
if td["cluster_shape"] != [1, 1, 1]:
|
||||
@ -397,6 +402,11 @@ class Conv2d(OperationBase):
|
||||
description_str = []
|
||||
for op in self.possible_operations.all_operations:
|
||||
td = datatypes.td_from_profiler_op(op)
|
||||
|
||||
if self._math_operation is not None:
|
||||
if td.math_instruction.math_operation != self._math_operation:
|
||||
continue
|
||||
|
||||
if str(td) not in description_str:
|
||||
description_str.append(str(td))
|
||||
descriptions.append(td)
|
||||
@ -569,7 +579,7 @@ class Conv2d(OperationBase):
|
||||
if self.tile_description is not None:
|
||||
tile_description = self.tile_description
|
||||
else:
|
||||
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0]
|
||||
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
||||
tile_description = datatypes.td_from_profiler_op(op)
|
||||
else:
|
||||
valid, err_str = self._valid_tile_description(tile_description)
|
||||
|
||||
@ -202,14 +202,14 @@ class Gemm(OperationBase):
|
||||
:type element_C: cutlass.DataType
|
||||
:param element_D: data type to be used for operand D
|
||||
:type element_D: cutlass.DataType
|
||||
:type layout_A: layout of operand A
|
||||
:param layout_A: cutlass.LayoutType
|
||||
:type layout_B: layout of operand B
|
||||
:param layout_B: cutlass.LayoutType
|
||||
:type layout_C: layout of operand C
|
||||
:param layout_C: cutlass.LayoutType
|
||||
:type layout_D: layout of operand D
|
||||
:param layout_D: cutlass.LayoutType
|
||||
:param layout_A: layout of operand A
|
||||
:type layout_A: cutlass.LayoutType
|
||||
:param layout_B: layout of operand B
|
||||
:type layout_B: cutlass.LayoutType
|
||||
:param layout_C: layout of operand C
|
||||
:type layout_C: cutlass.LayoutType
|
||||
:param layout_D: layout of operand D
|
||||
:type layout_D: cutlass.LayoutType
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -281,17 +281,23 @@ class Gemm(OperationBase):
|
||||
# Set the default op class
|
||||
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
|
||||
layout_comb = (self._layout_a, self._layout_b)
|
||||
|
||||
self.possible_op_classes = self.options.supporting_opclasses(
|
||||
self._element_a, self._element_b, self._element_accumulator,
|
||||
self._layout_a, self._layout_b)
|
||||
self._layout_a, self._layout_b, self._math_operation)
|
||||
|
||||
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
|
||||
self.opclass = cutlass.OpcodeClass.TensorOp
|
||||
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
|
||||
self.opclass = cutlass.OpcodeClass.Simt
|
||||
else:
|
||||
if self._math_operation is not None:
|
||||
math_op_str = f' and math operation {self._math_operation}'
|
||||
else:
|
||||
math_op_str = ''
|
||||
|
||||
raise Exception(f'No kernel configuration found for supported data type and layout '
|
||||
f'combination {datatype_comb}x{layout_comb}')
|
||||
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
|
||||
|
||||
if reset_epilogue:
|
||||
self._reset_epilogue_functor_activation(cutlass.epilogue.identity)
|
||||
@ -349,7 +355,7 @@ class Gemm(OperationBase):
|
||||
return
|
||||
if isinstance(td, dict):
|
||||
if self._tile_description is None:
|
||||
op = self.possible_operations.default_operation()
|
||||
op = self.possible_operations.default_operation(self._math_operation)
|
||||
self._tile_description = datatypes.td_from_profiler_op(op)
|
||||
td = self._tile_description.clone_and_update(td)
|
||||
|
||||
@ -394,7 +400,10 @@ class Gemm(OperationBase):
|
||||
:returns: list of valid tile descriptions for the operations
|
||||
:rtype: list
|
||||
"""
|
||||
return [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
|
||||
tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
|
||||
if self._math_operation is not None:
|
||||
tds = [td for td in tds if td.tile_description.math_instruction == self._math_operation]
|
||||
return tds
|
||||
|
||||
def construct(
|
||||
self, tile_description: TileDescription = None,
|
||||
@ -423,18 +432,19 @@ class Gemm(OperationBase):
|
||||
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
|
||||
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
|
||||
|
||||
alignment_pref_C = max(self.possible_operations.alignments("C"))
|
||||
if self._element_c != DataType.void:
|
||||
alignment_pref_C = min(128 // DataTypeSize[self._element_c], alignment_pref_C)
|
||||
|
||||
alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C)
|
||||
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
||||
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
||||
if alignment_C is None:
|
||||
alignment_C = max(self.possible_operations.alignments("C"))
|
||||
if self._element_c != DataType.void:
|
||||
alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
|
||||
|
||||
if tile_description is None:
|
||||
if self._tile_description is None:
|
||||
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0]
|
||||
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
||||
tile_description = datatypes.td_from_profiler_op(op)
|
||||
|
||||
# The selected op may have lower alignment than that determined above, so we must
|
||||
# reset alignment here.
|
||||
alignment_C = op.C.alignment
|
||||
else:
|
||||
tile_description = self._tile_description
|
||||
else:
|
||||
@ -443,6 +453,9 @@ class Gemm(OperationBase):
|
||||
raise Exception(f"Invalid tile description. {err_str}")
|
||||
self._tile_description = tile_description
|
||||
|
||||
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
||||
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=self.current_cc,
|
||||
tile_description=tile_description,
|
||||
@ -599,9 +612,13 @@ class Gemm(OperationBase):
|
||||
"""
|
||||
dtype, layout = datatypes.get_datatype_and_layout(tensor)
|
||||
if dtype != ref_type or layout != ref_layout:
|
||||
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
|
||||
f'does not match the expected type and '
|
||||
f'layout of ({ref_type}, {ref_layout}).')
|
||||
try:
|
||||
# Attempt to transpose the tensor to fit the desired layout
|
||||
tensor = tensor.transpose(-1, -2)
|
||||
except:
|
||||
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
|
||||
f'does not match the expected type and '
|
||||
f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
|
||||
|
||||
def run(self, A=None, B=None, C=None, D=None,
|
||||
alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments:
|
||||
|
||||
@ -174,7 +174,7 @@ class GroupedGemm(Gemm):
|
||||
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
||||
|
||||
if tile_description is None:
|
||||
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0]
|
||||
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
||||
tile_description = datatypes.td_from_profiler_op(op)
|
||||
else:
|
||||
valid, err_str = self._valid_tile_description(tile_description)
|
||||
|
||||
@ -36,7 +36,13 @@ Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv
|
||||
|
||||
from bisect import bisect_left
|
||||
|
||||
from cutlass_library import DataType, DataTypeSize, OperationKind, SharedMemPerCC
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
DataTypeSize,
|
||||
MathOperation,
|
||||
OperationKind,
|
||||
SharedMemPerCC
|
||||
)
|
||||
|
||||
import cutlass
|
||||
from cutlass import get_option_registry
|
||||
@ -67,6 +73,7 @@ class OperationBase:
|
||||
self.specified_kernel_cc = kernel_cc is not None
|
||||
self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
|
||||
self.tile_description = None
|
||||
self._math_operation = None
|
||||
|
||||
self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)
|
||||
|
||||
@ -197,14 +204,10 @@ class OperationBase:
|
||||
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
|
||||
return tensor
|
||||
|
||||
#
|
||||
# Opcode Related
|
||||
#
|
||||
|
||||
@property
|
||||
def opclass(self) -> cutlass.OpcodeClass:
|
||||
"""
|
||||
Returns the opcode class currently in use by the GEMM
|
||||
Returns the opcode class currently in use
|
||||
|
||||
:return: opcode class currently in use
|
||||
:rtype: cutlass.OpcodeClass
|
||||
@ -226,15 +229,41 @@ class OperationBase:
|
||||
# Changing the op class also changes the possible operations available. Reset these.
|
||||
self.possible_operations = self.options.operations(
|
||||
self.op_class, self._element_a, self._element_b,
|
||||
self._element_accumulator, self._layout_a, self._layout_b)
|
||||
self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
|
||||
|
||||
# Changing the op class changes the elements per access in the epilogue. Reset this.
|
||||
if self.epilogue_functor is not None:
|
||||
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
|
||||
|
||||
#
|
||||
# Epilogue
|
||||
#
|
||||
@property
|
||||
def math_operation(self) -> cutlass.MathOperation:
|
||||
"""
|
||||
Returns the math operation currently in use
|
||||
|
||||
:return: math operation currently in use
|
||||
:rtype: cutlass.MathOperation
|
||||
"""
|
||||
return self._math_operation
|
||||
|
||||
@math_operation.setter
|
||||
def math_operation(self, mo: cutlass.MathOperation):
|
||||
if isinstance(mo, str):
|
||||
mo = datatypes.getattr_enum(cutlass.MathOperation, mo)
|
||||
|
||||
if not self.specified_kernel_cc:
|
||||
if self.current_cc == 90:
|
||||
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
|
||||
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
||||
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
self._reset_options(80)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
elif self.current_cc == 90:
|
||||
raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
|
||||
"To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
|
||||
"parameter when constructing the plan.")
|
||||
|
||||
self._math_operation = mo
|
||||
self._reset_operations()
|
||||
|
||||
def _elements_per_access(self):
|
||||
if self.op_class == cutlass.OpcodeClass.Simt:
|
||||
@ -262,7 +291,7 @@ class OperationBase:
|
||||
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
|
||||
self._reset_options(80)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
elif (self.cc == 90 and self.current_cc != 90 and activation == identity):
|
||||
elif (self.cc == 90 and self.current_cc != 90 and activation == identity and self._math_operation is None):
|
||||
# SM80 fallback kernels are currently used. Since an identity activation is requested,
|
||||
# we can switch back to using SM90 kernels.
|
||||
self._reset_options(90)
|
||||
@ -272,7 +301,7 @@ class OperationBase:
|
||||
raise Exception("Epilogues with elementwise fusion are not currently supported "
|
||||
"in the Python interface for 3.x kernels. To use 2.x kernels "
|
||||
"with fused elementwise epilogues, do not set the `kernel_cc` "
|
||||
"parameter when constructing the Gemm object.")
|
||||
"parameter when constructing the plan.")
|
||||
|
||||
return get_activation_epilogue(
|
||||
activation,
|
||||
@ -364,6 +393,7 @@ class OperationBase:
|
||||
# The shared memory is only a concern for sm90 epilogue
|
||||
# In sm80, the epilogue and mainloop share the shared memory
|
||||
return
|
||||
|
||||
datatype_comb = self.possible_operations.datatype_comb
|
||||
layout_comb = self.possible_operations.layout_comb
|
||||
new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
|
||||
|
||||
@ -176,6 +176,17 @@ def is_torch_available():
|
||||
cutlass.DataType.s32: torch.int32,
|
||||
cutlass.DataType.u8: torch.uint8,
|
||||
}
|
||||
|
||||
def possibly_add_type(torch_type_name, cutlass_type):
|
||||
# Only try adding the type if the version of torch being used supports it
|
||||
if hasattr(torch, torch_type_name):
|
||||
torch_type = getattr(torch, torch_type_name)
|
||||
_torch_to_library_dict[torch_type] = cutlass_type
|
||||
_library_to_torch_dict[cutlass_type] = torch_type
|
||||
|
||||
possibly_add_type("float8_e4m3fn", cutlass.DataType.e4m3)
|
||||
possibly_add_type("float8_e5m2", cutlass.DataType.e5m2)
|
||||
|
||||
except ImportError:
|
||||
torch_available = False
|
||||
_torch_to_library_dict = {}
|
||||
|
||||
@ -61,7 +61,8 @@ class GemmOperation:
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default, extra_args = None):
|
||||
tile_scheduler = TileSchedulerType.Default
|
||||
):
|
||||
|
||||
kinds_3x = {
|
||||
GemmKind.Universal3x,
|
||||
@ -88,6 +89,10 @@ class GemmOperation:
|
||||
self.epilogue_schedule = epilogue_schedule
|
||||
self.element_epilogue = element_epilogue
|
||||
self.epilogue_functor = epilogue_functor
|
||||
|
||||
if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination:
|
||||
self.epilogue_functor = EpilogueFunctor3x.LinearCombination
|
||||
|
||||
self.swizzling_functor = swizzling_functor
|
||||
self.tile_scheduler = tile_scheduler
|
||||
|
||||
@ -709,9 +714,9 @@ class EmitGemmUniversal3xInstance:
|
||||
]
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
${element_d},
|
||||
${element_epilogue},
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
@ -726,7 +731,8 @@ using ${operation_name}_epilogue =
|
||||
${element_accumulator}, ${element_epilogue},
|
||||
${element_c}, ${layout_c}, ${align_c},
|
||||
${element_d}, ${layout_d}, ${align_d},
|
||||
${epilogue_schedule}
|
||||
${epilogue_schedule},
|
||||
${epilogue_functor}
|
||||
>::CollectiveOp;
|
||||
|
||||
using ${operation_name}_mainloop =
|
||||
@ -757,9 +763,11 @@ struct ${operation_name} :
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
|
||||
manifest.append(
|
||||
new ${gemm_kind}<GemmKernel>("${operation_name}"));
|
||||
{
|
||||
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
|
||||
manifest.append(
|
||||
new ${gemm_kind}<GemmKernel>("${operation_name}"));
|
||||
}
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
@ -788,9 +796,8 @@ ${compile_guard_end}
|
||||
# Support built-in epilogue functors or user-defined functions
|
||||
if isinstance(operation.epilogue_functor, enum.Enum):
|
||||
values = {
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
||||
else:
|
||||
@ -799,6 +806,9 @@ ${compile_guard_end}
|
||||
element_a = DataTypeTag[operation.A.element]
|
||||
element_b = DataTypeTag[operation.B.element]
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
element_a = DataTypeTag[operation.A.element]
|
||||
element_b = DataTypeTag[operation.B.element]
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
|
||||
@ -192,14 +192,14 @@ def CreateGemmUniversal3xOperator(
|
||||
C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1])
|
||||
D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1])
|
||||
|
||||
extra_args = {}
|
||||
gemm_op_extra_args = {}
|
||||
gemm_kind = GemmKind.Universal3x
|
||||
element_compute = data_type.get("epi_type", data_type["acc_type"])
|
||||
|
||||
operation = GemmOperation(
|
||||
gemm_kind, tile_description.minimum_compute_capability,
|
||||
tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D,
|
||||
kernel_schedule, epilogue_schedule, tile_scheduler, extra_args)
|
||||
kernel_schedule, epilogue_schedule, tile_scheduler, **gemm_op_extra_args)
|
||||
|
||||
manifest.append(operation)
|
||||
operations.append(operation)
|
||||
|
||||
@ -466,6 +466,13 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
||||
}
|
||||
|
||||
class EpilogueFunctor3x(enum.Enum):
|
||||
LinearCombination = enum_auto()
|
||||
#
|
||||
EpilogueFunctor3xTag = {
|
||||
EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
|
||||
}
|
||||
|
||||
class TileSchedulerType(enum.Enum):
|
||||
Default = enum_auto()
|
||||
Persistent = enum_auto()
|
||||
|
||||
@ -429,7 +429,7 @@ class Manifest:
|
||||
self.kernel_filter_list = []
|
||||
else:
|
||||
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
||||
_LOGGER.info("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
filter_count = len(self.kernel_filter_list),
|
||||
filter_file = args.kernel_filter_file))
|
||||
|
||||
|
||||
@ -101,7 +101,7 @@ class Layout(LayoutBase):
|
||||
|
||||
# cosize(layout) Size of the codomain
|
||||
def cosize(self):
|
||||
return tuple_max(tuple((1, elem_scale(self.shape, self.stride))))
|
||||
return self(self.size() - 1) + 1
|
||||
|
||||
# print and str
|
||||
def __str__(self):
|
||||
|
||||
@ -51,7 +51,7 @@ setup_pycute.perform_setup()
|
||||
|
||||
setup(
|
||||
name='cutlass',
|
||||
version='3.3.0',
|
||||
version='3.4.0',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='3.3.0',
|
||||
version='3.4.0',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='3.3.0',
|
||||
version='3.4.0',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user