CUTLASS 3.6.0 (#1850)
* v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -121,7 +121,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '3.5.1'
|
||||
this.__version__ = '3.6.0'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
|
||||
@ -37,7 +37,7 @@ import numpy as np
|
||||
from scipy.special import erf
|
||||
|
||||
from cutlass_library import DataType, DataTypeTag
|
||||
from cutlass.backend.c_types import MatrixCoord_
|
||||
from cutlass.backend.c_types import MatrixCoord_, tuple_factory
|
||||
from cutlass.backend.frontend import NumpyFrontend
|
||||
from cutlass.backend.library import ActivationOp, ActivationOpTag
|
||||
from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
|
||||
@ -162,11 +162,15 @@ class LinearCombination(EpilogueFunctorBase):
|
||||
Epilogue params when using the default linear combination of EVT, which
|
||||
does not currently use {alpha,beta}_ptr_array
|
||||
"""
|
||||
|
||||
stride_type = tuple_factory((0,0,1), "int64_t", [0])
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
("dalpha", stride_type),
|
||||
("dbeta", stride_type),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
|
||||
@ -164,7 +164,7 @@ class Sm90RowBroadcastImpl(RowBroadcastImpl):
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
@ -183,7 +183,7 @@ class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
|
||||
@ -253,8 +253,8 @@ _CUTLASS_TYPE_TO_TORCH_TYPE = {
|
||||
DataType.f16: "torch::kF16",
|
||||
DataType.f32: "torch::kF32",
|
||||
DataType.f64: "torch::kF64",
|
||||
DataType.s8: "torch::I8",
|
||||
DataType.s32: "torch::I32",
|
||||
DataType.s8: "torch::kI8",
|
||||
DataType.s32: "torch::kI32",
|
||||
}
|
||||
|
||||
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
|
||||
|
||||
@ -94,9 +94,12 @@ using ${operation_name}_mainloop =
|
||||
${kernel_schedule}
|
||||
>::CollectiveOp;
|
||||
|
||||
using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>;
|
||||
|
||||
// Unit tests call this "ConvKernel".
|
||||
// Conv operator ${operation_name}
|
||||
using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
${operation_name}_problem_shape,
|
||||
${operation_name}_mainloop,
|
||||
${operation_name}_epilogue,
|
||||
${tile_scheduler}
|
||||
|
||||
@ -710,14 +710,14 @@ class EmitGemmUniversal3xInstance:
|
||||
"cutlass/gemm/collective/collective_builder.hpp",
|
||||
"cutlass/epilogue/collective/collective_builder.hpp",
|
||||
]
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
self.builtin_epilogue_functor_template = \
|
||||
"""${epilogue_functor}<
|
||||
${element_d},
|
||||
${element_epilogue},
|
||||
${element_c},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
>"""
|
||||
|
||||
self.gemm_template = """
|
||||
|
||||
using ${operation_name}_epilogue =
|
||||
@ -778,7 +778,6 @@ ${compile_guard_end}
|
||||
|
||||
opcode_class_main = operation.tile_description.math_instruction.opcode_class
|
||||
opcode_class_epi = opcode_class_main
|
||||
|
||||
tile_shape = operation.tile_description.tile_shape
|
||||
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
@ -1057,14 +1056,14 @@ class EmitGemmGroupedInstance:
|
||||
"cutlass/gemm/kernel/default_gemm_grouped.h",
|
||||
"cutlass/gemm/device/gemm_grouped.h"
|
||||
]
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
self.builtin_epilogue_functor_template = \
|
||||
"""${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
>"""
|
||||
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
@ -1183,6 +1182,7 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Sparse: EmitSparseGemmInstance,
|
||||
GemmKind.Universal: EmitGemmUniversalInstance,
|
||||
GemmKind.Universal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance
|
||||
@ -1193,6 +1193,7 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Sparse: 'GemmSparseOperation',
|
||||
GemmKind.Universal: 'GemmUniversalOperation',
|
||||
GemmKind.Universal3x: 'GemmUniversal3xOperation',
|
||||
GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation',
|
||||
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
|
||||
GemmKind.Grouped: 'GemmGroupedOperation'
|
||||
@ -1252,6 +1253,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
("library_internal.h", None),
|
||||
("gemm_operation.h", None),
|
||||
("gemm_operation_3x.hpp", None),
|
||||
("sparse_gemm_operation_3x.hpp", None),
|
||||
("cutlass/arch/wmma.h", None),
|
||||
("cutlass/numeric_types.h", None)
|
||||
])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -588,12 +588,14 @@ OpcodeClassNames = {
|
||||
OpcodeClass.Simt: 'simt',
|
||||
OpcodeClass.TensorOp: 'tensorop',
|
||||
OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
|
||||
OpcodeClass.SparseTensorOp: 'sptensorop',
|
||||
}
|
||||
|
||||
OpcodeClassTag = {
|
||||
OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
|
||||
OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
|
||||
OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
|
||||
OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
@ -520,7 +520,9 @@ class Manifest:
|
||||
raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures)
|
||||
architectures = args.architectures.split(';') if len(args.architectures) else ['50',]
|
||||
|
||||
arch_conditional_cc = ['90a']
|
||||
arch_conditional_cc = [
|
||||
'90a',
|
||||
]
|
||||
architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures]
|
||||
|
||||
self.compute_capabilities = [int(x) for x in architectures]
|
||||
@ -560,6 +562,18 @@ class Manifest:
|
||||
self.operation_count = 0
|
||||
self.operations_by_name = {}
|
||||
self.disable_full_archs_compilation = args.disable_full_archs_compilation
|
||||
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
|
||||
|
||||
def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9999):
|
||||
# Non-negative integer which determines how many kernels are instantiated.
|
||||
# 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations.
|
||||
# increasing first digit reduces schedule / mixed type pruning,
|
||||
# increasing second digit generates more cluster sizes,
|
||||
# increasing third digit generates more MMA shapes,
|
||||
# increasing fourth digit generates more instruction shapes.
|
||||
return exhaustive_level if self.is_kernel_filter_set_to_all else (
|
||||
pruned_level if self.kernel_filter == '' else default_level
|
||||
)
|
||||
|
||||
|
||||
def get_kernel_filters (self, kernelListFile):
|
||||
@ -601,6 +615,7 @@ class Manifest:
|
||||
enabled = not (self.filter_by_cc)
|
||||
|
||||
for cc in self.compute_capabilities:
|
||||
|
||||
if cc >= operation.tile_description.minimum_compute_capability and \
|
||||
cc <= operation.tile_description.maximum_compute_capability and \
|
||||
(cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)):
|
||||
|
||||
212
python/cutlass_library/sm90_shapes.py
Normal file
212
python/cutlass_library/sm90_shapes.py
Normal file
@ -0,0 +1,212 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Valid WGMMA shapes, MMA multipliers, and cluster sizes for SM90, associated with levels.
|
||||
These shape and level pairs are defined as dicts, where keys are shapes and values are their
|
||||
associated levels. If the user input level for that category (MMA multiplier, WGMMA shape, cluster
|
||||
size) is smaller than a shape's associated level, it will be excluded, and otherwise, included.
|
||||
Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently.
|
||||
Level 0 is always emitted. The default behavior in `generator.py` is that level 1 is only emitted
|
||||
when the `--kernel` argument is non-empty.
|
||||
"""
|
||||
|
||||
# NOTE: more combinations are possible here.
|
||||
# Levels [0, 3] exist in order to control exactly what configs are generated in different dtypes.
|
||||
# The rest are only used in the exhaustive mode (when the corresponding level digit is 9).
|
||||
# MMA multipliers are multiplied by MMA instruction shapes (WGMMA shapes) to get CTA shapes.
|
||||
SM90_MMA_MULTIPLIERS = {
|
||||
(2, 1, 4): 0,
|
||||
(1, 1, 4): 1,
|
||||
(4, 1, 4): 2,
|
||||
(2, 2, 4): 3,
|
||||
(2, 1, 8): 4,
|
||||
(4, 1, 8): 4,
|
||||
(1, 1, 8): 4,
|
||||
(2, 2, 8): 4,
|
||||
(2, 1, 16): 5,
|
||||
(4, 1, 16): 5,
|
||||
(1, 1, 16): 5,
|
||||
(2, 2, 16): 5,
|
||||
}
|
||||
|
||||
# Level 0: only (1, 2, 1) -- fp8 dense gemms in pruned case
|
||||
# Level 1: clusters with 2 CTAs -- all but fp8 (s8, u8, f16, b16, f32, tf32) dense gemms in pruned case
|
||||
# Level 2: clusters with 1 or 2 CTAs
|
||||
# Level 3: clusters with 1, 2, or 4 CTAs
|
||||
# Level 4: clusters with 1, 2, 4, or 8 CTAs
|
||||
# Level 5: clusters with 1, 2, 4, 8, or 16 CTAs
|
||||
SM90_CLUSTER_SIZES = {
|
||||
(1, 2, 1): 0,
|
||||
(2, 1, 1): 1,
|
||||
(1, 1, 1): 2,
|
||||
(2, 2, 1): 3,
|
||||
(1, 4, 1): 3,
|
||||
(4, 1, 1): 3,
|
||||
(2, 4, 1): 4,
|
||||
(4, 2, 1): 4,
|
||||
(1, 8, 1): 4,
|
||||
(8, 1, 1): 4,
|
||||
(4, 4, 1): 5,
|
||||
}
|
||||
|
||||
|
||||
# WGMMA shapes
|
||||
# Level 0: "default" shape only,
|
||||
# Level 1: additional shapes for the unpruned case (tf32 only)
|
||||
# Level 2: shapes that are all powers of 2
|
||||
# Level 3: all other shapes
|
||||
SM90_WGMMA_SHAPES_FP16_BF16_DENSE = {
|
||||
(64, 8, 16): 2,
|
||||
(64, 16, 16): 2,
|
||||
(64, 24, 16): 3,
|
||||
(64, 32, 16): 2,
|
||||
(64, 40, 16): 3,
|
||||
(64, 48, 16): 3,
|
||||
(64, 56, 16): 3,
|
||||
(64, 64, 16): 2,
|
||||
(64, 72, 16): 3,
|
||||
(64, 80, 16): 3,
|
||||
(64, 88, 16): 3,
|
||||
(64, 96, 16): 3,
|
||||
(64, 104, 16): 3,
|
||||
(64, 112, 16): 3,
|
||||
(64, 120, 16): 3,
|
||||
(64, 128, 16): 0,
|
||||
(64, 136, 16): 3,
|
||||
(64, 144, 16): 3,
|
||||
(64, 152, 16): 3,
|
||||
(64, 160, 16): 3,
|
||||
(64, 168, 16): 3,
|
||||
(64, 176, 16): 3,
|
||||
(64, 184, 16): 3,
|
||||
(64, 192, 16): 3,
|
||||
(64, 200, 16): 3,
|
||||
(64, 208, 16): 3,
|
||||
(64, 216, 16): 3,
|
||||
(64, 224, 16): 3,
|
||||
(64, 232, 16): 3,
|
||||
(64, 240, 16): 3,
|
||||
(64, 248, 16): 3,
|
||||
(64, 256, 16): 1,
|
||||
}
|
||||
|
||||
SM90_WGMMA_SHAPES_TF32_DENSE = {
|
||||
(64, 8, 8): 2,
|
||||
(64, 16, 8): 2,
|
||||
(64, 24, 8): 3,
|
||||
(64, 32, 8): 2,
|
||||
(64, 40, 8): 3,
|
||||
(64, 48, 8): 3,
|
||||
(64, 56, 8): 3,
|
||||
(64, 64, 8): 2,
|
||||
(64, 72, 8): 3,
|
||||
(64, 80, 8): 3,
|
||||
(64, 88, 8): 3,
|
||||
(64, 96, 8): 3,
|
||||
(64, 104, 8): 3,
|
||||
(64, 112, 8): 3,
|
||||
(64, 120, 8): 3,
|
||||
(64, 128, 8): 0,
|
||||
(64, 136, 8): 3,
|
||||
(64, 144, 8): 3,
|
||||
(64, 152, 8): 3,
|
||||
(64, 160, 8): 3,
|
||||
(64, 168, 8): 3,
|
||||
(64, 176, 8): 3,
|
||||
(64, 184, 8): 3,
|
||||
(64, 192, 8): 3,
|
||||
(64, 200, 8): 3,
|
||||
(64, 208, 8): 3,
|
||||
(64, 216, 8): 3,
|
||||
(64, 224, 8): 3,
|
||||
(64, 232, 8): 3,
|
||||
(64, 240, 8): 3,
|
||||
(64, 248, 8): 3,
|
||||
(64, 256, 8): 1,
|
||||
}
|
||||
|
||||
SM90_WGMMA_SHAPES_FP8_DENSE = {
|
||||
(64, 8, 32): 2,
|
||||
(64, 16, 32): 2,
|
||||
(64, 24, 32): 3,
|
||||
(64, 32, 32): 2,
|
||||
(64, 40, 32): 3,
|
||||
(64, 48, 32): 3,
|
||||
(64, 56, 32): 3,
|
||||
(64, 64, 32): 2,
|
||||
(64, 72, 32): 3,
|
||||
(64, 80, 32): 3,
|
||||
(64, 88, 32): 3,
|
||||
(64, 96, 32): 3,
|
||||
(64, 104, 32): 3,
|
||||
(64, 112, 32): 3,
|
||||
(64, 120, 32): 3,
|
||||
(64, 128, 32): 0,
|
||||
(64, 136, 32): 3,
|
||||
(64, 144, 32): 3,
|
||||
(64, 152, 32): 3,
|
||||
(64, 160, 32): 3,
|
||||
(64, 168, 32): 3,
|
||||
(64, 176, 32): 3,
|
||||
(64, 184, 32): 3,
|
||||
(64, 192, 32): 3,
|
||||
(64, 200, 32): 3,
|
||||
(64, 208, 32): 3,
|
||||
(64, 216, 32): 3,
|
||||
(64, 224, 32): 3,
|
||||
(64, 232, 32): 3,
|
||||
(64, 240, 32): 3,
|
||||
(64, 248, 32): 3,
|
||||
(64, 256, 32): 1,
|
||||
}
|
||||
|
||||
SM90_WGMMA_SHAPES_INT8_DENSE = {
|
||||
(64, 8, 32): 2,
|
||||
(64, 16, 32): 2,
|
||||
(64, 24, 32): 3,
|
||||
(64, 32, 32): 2,
|
||||
(64, 48, 32): 3,
|
||||
(64, 64, 32): 2,
|
||||
(64, 80, 32): 3,
|
||||
(64, 96, 32): 3,
|
||||
(64, 112, 32): 3,
|
||||
(64, 128, 32): 0,
|
||||
(64, 144, 32): 3,
|
||||
(64, 160, 32): 3,
|
||||
(64, 176, 32): 3,
|
||||
(64, 192, 32): 3,
|
||||
(64, 208, 32): 3,
|
||||
(64, 224, 32): 3,
|
||||
(64, 240, 32): 3,
|
||||
(64, 256, 32): 1,
|
||||
}
|
||||
601
python/cutlass_library/sm90_utils.py
Normal file
601
python/cutlass_library/sm90_utils.py
Normal file
@ -0,0 +1,601 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for enumerating CUTLASS library SM90 kernels
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import enum
|
||||
from itertools import product
|
||||
import math
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import copy
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
try:
|
||||
import builtins
|
||||
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
||||
raise ImportError("Disabling attempt to import cutlass_library")
|
||||
from cutlass_library.library import *
|
||||
except ImportError:
|
||||
from library import *
|
||||
|
||||
# NOTE: this is a duplicate of CudaToolkitVersionSatisfies in generator.py
|
||||
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
|
||||
|
||||
# by default, use the latest CUDA Toolkit version
|
||||
cuda_version = [11, 0, 132]
|
||||
|
||||
# Update cuda_version based on parsed string
|
||||
if semantic_ver_string != '':
|
||||
for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]):
|
||||
if i < len(cuda_version):
|
||||
cuda_version[i] = x
|
||||
else:
|
||||
cuda_version.append(x)
|
||||
return cuda_version >= [major, minor, patch]
|
||||
|
||||
#### Step 0: define levels
|
||||
|
||||
# One integer level controls multiple "generators" and how many
|
||||
# combinations they generate. That is the "global" level.
|
||||
# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and
|
||||
# anything that is eventually involved in the Cartesian product
|
||||
# which yields our kernel configurations.
|
||||
# For simplicity, each generator defines their own levels,
|
||||
# starting from 0. As a rule we assume 10 or fewer levels, making
|
||||
# their level a digit.
|
||||
# The "global" level simply stacks these digits and represents them
|
||||
# as a single integer.
|
||||
#
|
||||
# For example, level 500 indicates cluster sizes are at level 5, MMA
|
||||
# multipliers are at level 0, and WGMMA shapes are at level 0 as well.
|
||||
#
|
||||
# Here we define the global level to generator level mappings.
|
||||
|
||||
|
||||
def get_wgmma_level_from_global_level(global_level: int):
|
||||
return global_level % 10
|
||||
|
||||
|
||||
def get_mma_level_from_global_level(global_level: int):
|
||||
return (global_level // 10) % 10
|
||||
|
||||
|
||||
def get_cluster_level_from_global_level(global_level: int):
|
||||
return (global_level // 100) % 10
|
||||
|
||||
|
||||
def get_pruning_level_from_global_level(global_level: int):
|
||||
return (global_level // 1000) % 10
|
||||
|
||||
|
||||
#### Step 1: generate MMA instruction shapes based on levels
|
||||
|
||||
try:
|
||||
from .sm90_shapes import (
|
||||
SM90_MMA_MULTIPLIERS,
|
||||
SM90_CLUSTER_SIZES,
|
||||
SM90_WGMMA_SHAPES_TF32_DENSE,
|
||||
SM90_WGMMA_SHAPES_FP16_BF16_DENSE,
|
||||
SM90_WGMMA_SHAPES_FP8_DENSE,
|
||||
SM90_WGMMA_SHAPES_INT8_DENSE,
|
||||
)
|
||||
except:
|
||||
from sm90_shapes import (
|
||||
SM90_MMA_MULTIPLIERS,
|
||||
SM90_CLUSTER_SIZES,
|
||||
SM90_WGMMA_SHAPES_TF32_DENSE,
|
||||
SM90_WGMMA_SHAPES_FP16_BF16_DENSE,
|
||||
SM90_WGMMA_SHAPES_FP8_DENSE,
|
||||
SM90_WGMMA_SHAPES_INT8_DENSE,
|
||||
)
|
||||
|
||||
|
||||
def generate_tf32_math_instruction_shapes_sm90(level: int):
|
||||
assert isinstance(level, int) and level >= 0
|
||||
filtered_list_of_wgmma_shapes = [
|
||||
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_TF32_DENSE.items() if level >= min_level
|
||||
]
|
||||
return filtered_list_of_wgmma_shapes
|
||||
|
||||
def generate_fp16_bf16_math_instruction_shapes_sm90(level: int):
|
||||
assert isinstance(level, int) and level >= 0
|
||||
filtered_list_of_wgmma_shapes = [
|
||||
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP16_BF16_DENSE.items() if level >= min_level
|
||||
]
|
||||
return filtered_list_of_wgmma_shapes
|
||||
|
||||
def generate_fp8_math_instruction_shapes_sm90(level: int):
|
||||
assert isinstance(level, int) and level >= 0
|
||||
filtered_list_of_wgmma_shapes = [
|
||||
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP8_DENSE.items() if level >= min_level
|
||||
]
|
||||
return filtered_list_of_wgmma_shapes
|
||||
|
||||
def generate_int8_math_instruction_shapes_sm90(level: int):
|
||||
assert isinstance(level, int) and level >= 0
|
||||
filtered_list_of_wgmma_shapes = [
|
||||
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_INT8_DENSE.items() if level >= min_level
|
||||
]
|
||||
return filtered_list_of_wgmma_shapes
|
||||
|
||||
###########
|
||||
|
||||
def generate_tf32_math_instructions_sm90(level: int):
|
||||
wgmma_level = get_wgmma_level_from_global_level(level)
|
||||
math_instructions = []
|
||||
for math_instruction_shape in generate_tf32_math_instruction_shapes_sm90(wgmma_level):
|
||||
math_instructions.append(
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.tf32, DataType.tf32, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add)
|
||||
)
|
||||
return math_instructions
|
||||
|
||||
def generate_fp16_bf16_math_instructions_sm90(level: int):
|
||||
wgmma_level = get_wgmma_level_from_global_level(level)
|
||||
math_instructions = []
|
||||
for math_instruction_shape in generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level):
|
||||
math_instructions += [
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.f16, DataType.f16, DataType.f16,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.f16, DataType.f16, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.bf16, DataType.bf16, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
return math_instructions
|
||||
|
||||
def generate_fp8_math_instructions_sm90(level: int):
|
||||
wgmma_level = get_wgmma_level_from_global_level(level)
|
||||
math_instructions = []
|
||||
for math_instruction_shape in generate_fp8_math_instruction_shapes_sm90(wgmma_level):
|
||||
math_instructions += [
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.e4m3, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.e5m2, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.e5m2, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
return math_instructions
|
||||
|
||||
def generate_int8_math_instructions_sm90(level: int):
|
||||
wgmma_level = get_wgmma_level_from_global_level(level)
|
||||
math_instructions = []
|
||||
for math_instruction_shape in generate_int8_math_instruction_shapes_sm90(wgmma_level):
|
||||
math_instructions += [
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.s8, DataType.s8, DataType.s32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
math_instruction_shape,
|
||||
DataType.u8, DataType.u8, DataType.s32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
return math_instructions
|
||||
|
||||
def make_sparse_math_instructions(math_instructions):
|
||||
sparse_instructions = []
|
||||
for inst in math_instructions:
|
||||
if inst.opcode_class == OpcodeClass.TensorOp:
|
||||
sparse_instructions.append(MathInstruction(
|
||||
(inst.instruction_shape[0], inst.instruction_shape[1], inst.instruction_shape[2] * 2),
|
||||
inst.element_a, inst.element_b, inst.element_accumulator,
|
||||
OpcodeClass.SparseTensorOp,
|
||||
inst.math_operation),)
|
||||
return sparse_instructions
|
||||
|
||||
|
||||
#### Step 2: generate tile descriptions from math instruction shapes
|
||||
|
||||
def is_tile_desc_valid(tile_description):
|
||||
if tile_description.minimum_compute_capability != 90 or tile_description.maximum_compute_capability != 90:
|
||||
return False
|
||||
|
||||
element_a, element_b, element_accum = (
|
||||
tile_description.math_instruction.element_a,
|
||||
tile_description.math_instruction.element_b,
|
||||
tile_description.math_instruction.element_accumulator
|
||||
)
|
||||
|
||||
cluster_shape, cta_shape, inst_shape = (
|
||||
tile_description.cluster_shape,
|
||||
tile_description.threadblock_shape,
|
||||
tile_description.math_instruction.instruction_shape
|
||||
)
|
||||
grid_size = (
|
||||
cta_shape[0] * cluster_shape[0] +
|
||||
cta_shape[1] * cluster_shape[1] +
|
||||
cta_shape[2] * cluster_shape[2]
|
||||
)
|
||||
cluster_size = cluster_shape[0] * cluster_shape[1] * cluster_shape[2]
|
||||
|
||||
# Maximum number of CTAs per cluster is 8 for Hopper, but up to 16 is
|
||||
# allowed for non portable clusters.
|
||||
if cluster_size > 16 or cluster_size < 1:
|
||||
return False
|
||||
|
||||
if grid_size < 1:
|
||||
return False
|
||||
|
||||
# SM90 WGMMA shapes are always 64 across M, therefore
|
||||
# CTA shape across M must always be a multiple of 64.
|
||||
if cta_shape[0] < 64 or cta_shape[0] % 64 != 0:
|
||||
return False
|
||||
|
||||
# The minimum WGMMA shape across N is 8, and increments
|
||||
# vary across different dtypes, but they're never smaller
|
||||
# than 8. The minimum CTA shape allowed across N though is 16.
|
||||
if cta_shape[1] < 16 or cta_shape[1] % 8 != 0:
|
||||
return False
|
||||
|
||||
# SM90 WGMMA shapes across K are always 8 for 32 bit dense
|
||||
# operations, 16 for 16 bit, and 32 for 8 bit. In any case,
|
||||
# the CTA shape across K should be a multiple of 8 and at least
|
||||
# twice the WGMMA shape across K.
|
||||
if cta_shape[2] < 16 or cta_shape[2] % 8 != 0:
|
||||
return False
|
||||
|
||||
# Minimum of 2 stages
|
||||
if cta_shape[2] < inst_shape[2] or cta_shape[2] % inst_shape[2] != 0 or cta_shape[2] / inst_shape[2] < 2:
|
||||
return False
|
||||
|
||||
# CTA shape upper bound: <256, 256, 256>
|
||||
if cta_shape[0] > 256 or cta_shape[1] > 256 or cta_shape[2] > 256:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_mma_multipliers(level: int):
|
||||
assert isinstance(level, int) and level >= 0
|
||||
mma_level = get_mma_level_from_global_level(level)
|
||||
return [
|
||||
mma_mul for mma_mul, mma_min_level in SM90_MMA_MULTIPLIERS.items() if mma_level >= mma_min_level
|
||||
]
|
||||
|
||||
def get_cluster_sizes(level: int, is_aligned: bool):
|
||||
if not is_aligned:
|
||||
return [(1, 1, 1)]
|
||||
assert isinstance(level, int) and level >= 0
|
||||
cluster_level = get_cluster_level_from_global_level(level)
|
||||
return [
|
||||
cluster_size for cluster_size, cluster_min_level in SM90_CLUSTER_SIZES.items() if cluster_level >= cluster_min_level
|
||||
]
|
||||
|
||||
def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: int):
|
||||
tile_descriptions = set()
|
||||
mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned)
|
||||
for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes):
|
||||
tile_desc = TileDescription(
|
||||
threadblock_shape=[
|
||||
math_inst.instruction_shape[0] * mma_mul[0],
|
||||
math_inst.instruction_shape[1] * mma_mul[1],
|
||||
math_inst.instruction_shape[2] * mma_mul[2]
|
||||
],
|
||||
stages=0,
|
||||
warp_count=[4, 1, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=90,
|
||||
max_compute=90,
|
||||
cluster_shape=cluster_size)
|
||||
# For sparse kernels K-tile is twice as large (due to 2x MMA-K size)
|
||||
# Reduce it to same size as dense to afford more smem stages
|
||||
if math_inst.opcode_class == OpcodeClass.SparseTensorOp:
|
||||
tile_desc.threadblock_shape[2] = tile_desc.threadblock_shape[2] // 2
|
||||
if is_tile_desc_valid(tile_desc):
|
||||
tile_descriptions.add(tile_desc)
|
||||
|
||||
return tile_descriptions
|
||||
|
||||
#### Step 3: map tile description to valid schedules
|
||||
|
||||
def is_tile_desc_compatible_with_cooperative(tile_description):
|
||||
# Cooperative kernels require a minimum CTA-M of 128
|
||||
return tile_description.threadblock_shape[0] >= 128
|
||||
|
||||
|
||||
def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):
|
||||
dtype_a, dtype_b, dtype_c, dtype_d, dtype_acc, dtype_epi = (
|
||||
data_types["a_type"],
|
||||
data_types["b_type"],
|
||||
data_types["c_type"],
|
||||
data_types["d_type"],
|
||||
data_types["acc_type"],
|
||||
data_types["epi_type"]
|
||||
)
|
||||
mn = tile_description.threadblock_shape[0] * tile_description.threadblock_shape[1]
|
||||
bitsize_c, bitsize_d = DataTypeSize[dtype_c], DataTypeSize[dtype_d]
|
||||
|
||||
shmem_bits_c, shmem_bits_d = bitsize_c * mn, bitsize_d * mn
|
||||
shmem_bits_total = shmem_bits_c + shmem_bits_d
|
||||
# Magic number: 2^20
|
||||
# Existing logic suggested that tile shape 256x128 (or 128x256)
|
||||
# would run out of shmem if D is FP32, and source is needed.
|
||||
# That would be 256 * 128 * 32 == 2^21 (~262 KB), which is over the limit.
|
||||
# Hopper's max shmem size is 228 KB, and 2^20 ~= 131 KB.
|
||||
# Since epilogue can't possibly use ALL of the shmem available
|
||||
# we can just settle on 2^20 bits (~ 131 KB) being the upper bound
|
||||
# we would allow for epilogue.
|
||||
# This can be different for non-persistent kernels where epilogue and
|
||||
# mainloop shmem is shared.
|
||||
if shmem_bits_total > 2 ** 20:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout,
                        instantiation_level, enable_fp8_fast_acc=True):
    """Map a tile description to the kernel/epilogue schedule pairs it may be emitted with.

    Returns a pair ``(schedules, stream_k_schedules)``, each a list of
    ``[KernelScheduleType, EpilogueScheduleType]`` pairs. Either list may be
    empty when the tile is pruned away entirely.
    """
    # Level 0: prune according to existing generator.py behavior
    # Level >= 1: no pruning
    level = get_pruning_level_from_global_level(instantiation_level)
    schedules = []
    stream_k_schedules = []

    if not is_tile_desc_valid(tile_description):
        return schedules, stream_k_schedules

    FP16_TYPES = [DataType.f16, DataType.bf16]
    is_fp16 = data_types["a_type"] in FP16_TYPES and data_types["b_type"] in FP16_TYPES

    FP8_TYPES = [DataType.e4m3, DataType.e5m2]
    is_fp8 = data_types["a_type"] in FP8_TYPES and data_types["b_type"] in FP8_TYPES
    can_do_fp8_fast_accum = is_fp8 and enable_fp8_fast_acc

    FP32_TYPES = [DataType.f32, DataType.tf32]
    is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES
    # FP32 row-row layouts are handled via a transposed epilogue instead of the usual one.
    requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor

    is_sparse = tile_description.math_instruction.opcode_class == OpcodeClass.SparseTensorOp

    can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description)
    # TMA epilogues need aligned operands, a non-transposed epilogue, and enough
    # shared memory to stage the C/D tiles.
    can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types)

    default_epilogue = EpilogueScheduleType.NoSmemWarpSpecialized if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed
    auto_epilogue = EpilogueScheduleType.ScheduleAuto if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed

    cta_m, cta_n, cta_k = (
        tile_description.threadblock_shape[0],
        tile_description.threadblock_shape[1],
        tile_description.threadblock_shape[2]
    )
    c_type = data_types["c_type"]
    d_type = data_types["d_type"]
    is_void_c = c_type == DataType.void

    # Early pruning
    if level < 1:
        # Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64
        if is_fp16 and cta_m <= 64 and cta_n <= 128 and cta_k <= 64:
            return [], []

        # FP8 configs with CTA tile larger than or equal to 256x128x128 limit data types and schedules
        is_large_fp8_tile = is_fp8 and cta_m >= 256 and cta_n >= 128 and cta_k >= 128
        if is_large_fp8_tile:
            # Only void-C, and only FP8 outputs allowed
            if not is_void_c or d_type not in FP8_TYPES:
                return [], []
            if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
                # Large FP8 tiles get exactly two cooperative TMA-epilogue configs;
                # sparse kernels swap the plain cooperative mainloop for FastAccum.
                return [
                    [
                        KernelScheduleType.TmaWarpSpecializedCooperative if not is_sparse else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
                        EpilogueScheduleType.TmaWarpSpecializedCooperative
                    ],
                    [
                        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
                        EpilogueScheduleType.TmaWarpSpecializedCooperative
                    ],
                ] , []
            return [], []

        if is_fp8 and not is_large_fp8_tile:
            valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16]
            # Prune all configs with fp8 source, and all configs with non-fp8 output
            # that have different dtypes for source and output.
            if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type):
                return [], []

        # FP32/TF32 kernels don't stamp out void-C
        if is_fp32 and is_void_c:
            return [], []

        # Void-c only makes a difference for TMA epilogues
        if is_void_c and not can_do_tma_epilogue:
            return [], []

    if not is_aligned:
        # Unaligned kernels only get cp.async mainloop schedules.
        schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
                      default_epilogue]]
        stream_k_schedules = []

        if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative:
            schedules.append([
                KernelScheduleType.CpAsyncWarpSpecializedCooperative,
                default_epilogue
            ])
            stream_k_schedules.append([
                KernelScheduleType.CpAsyncWarpSpecializedCooperative,
                default_epilogue
            ])

        return schedules, stream_k_schedules

    schedules = []
    # Pruning: emit Void-C kernels with persistent kernels only
    if level >= 1 or not is_void_c:
        # Pruning: don't stamp out fp8 kernels with auto schedule
        if not is_fp8:
            schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
        if not (is_fp8 and is_sparse):
            schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue])
    stream_k_schedules = []

    if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
        # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
        if not is_fp8 or level >= 1:
            schedules.append([KernelScheduleType.TmaWarpSpecializedPingpong, default_epilogue])

        if can_do_fp8_fast_accum:
            schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
            schedules.append([KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, default_epilogue])

        if can_do_cooperative:
            # Sparse kernels only support FastAccum FP8 mainloop
            if not (is_fp8 and is_sparse):
                schedules.append([
                    KernelScheduleType.TmaWarpSpecializedCooperative,
                    default_epilogue
                ])
                stream_k_schedules.append([
                    KernelScheduleType.TmaWarpSpecializedCooperative,
                    default_epilogue
                ])
            if can_do_fp8_fast_accum:
                schedules.append([
                    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
                    default_epilogue
                ])
                stream_k_schedules.append([
                    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
                    default_epilogue
                ])

        # persistent kernels with TMA epilogues
        if can_do_tma_epilogue:
            assert not requires_transposed_epilogue
            # Inconsistency: fp8 pingpong only gets stamped out with fast accum
            if not is_fp8 or level >= 1:
                schedules.append([
                    KernelScheduleType.TmaWarpSpecializedPingpong,
                    EpilogueScheduleType.TmaWarpSpecialized
                ])
            if can_do_fp8_fast_accum:
                schedules.append([
                    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
                    EpilogueScheduleType.TmaWarpSpecialized
                ])
            if can_do_cooperative:
                # Sparse kernels only support FastAccum FP8 mainloop
                if not (is_fp8 and is_sparse):
                    schedules.append([
                        KernelScheduleType.TmaWarpSpecializedCooperative,
                        EpilogueScheduleType.TmaWarpSpecializedCooperative
                    ])
                    stream_k_schedules.append([
                        KernelScheduleType.TmaWarpSpecializedCooperative,
                        EpilogueScheduleType.TmaWarpSpecializedCooperative
                    ])
                if can_do_fp8_fast_accum:
                    schedules.append([
                        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
                        EpilogueScheduleType.TmaWarpSpecializedCooperative
                    ])
                    stream_k_schedules.append([
                        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
                        EpilogueScheduleType.TmaWarpSpecializedCooperative
                    ])

    return schedules, stream_k_schedules
|
||||
|
||||
|
||||
#### Misc: helpers
|
||||
|
||||
def generate_data_types_from_math_instruction(math_instruction, element_source = None, element_dest = None, element_epilogue = None):
    """Build the operand data-type dictionary for a math instruction.

    A/B types come straight from the math instruction; C (source), D (dest)
    and the epilogue compute type default to the accumulator type when not
    explicitly provided.

    :param math_instruction: instruction carrying element_a/element_b/element_accumulator
    :param element_source: optional override for the C (source) type
    :param element_dest: optional override for the D (output) type
    :param element_epilogue: optional override for the epilogue compute type
    :return: dict with keys a_type, b_type, c_type, d_type, acc_type, epi_type
    """
    acc = math_instruction.element_accumulator
    return {
        "a_type"   : math_instruction.element_a,
        "b_type"   : math_instruction.element_b,
        "c_type"   : element_source or acc,
        "d_type"   : element_dest or acc,
        "acc_type" : acc,
        "epi_type" : element_epilogue or acc,
    }
|
||||
|
||||
def fix_alignments(data_types, layout, alignment_bits = 128):
    """Recompute the C operand's alignment from its data type.

    A/B alignments are passed through untouched; the C entry's alignment is
    replaced with ``alignment_bits // sizeof(c_type)``, except when the C type
    has zero size (i.e. it was changed to void), in which case it is also left
    as-is.

    :param data_types: dict containing a_type, b_type and c_type entries
    :param layout: per-operand list of [LayoutType, alignment] entries for A, B, C
    :param alignment_bits: target alignment in bits (default 128)
    :return: a new layout list with the C alignment fixed up
    """
    operand_keys = ["a_type", "b_type", "c_type"]
    assert len(layout) == len(operand_keys)

    fixed_layout = []
    for key, entry in zip(operand_keys, layout):
        assert key in data_types and data_types[key] in DataTypeSize
        size_bits = DataTypeSize[data_types[key]]
        layout_type, alignment = entry[0], entry[1]

        # Only C gets fixed; skip it too when its dtype's been changed to void (size 0)
        if key == "c_type" and size_bits >= 1:
            alignment = alignment_bits // size_bits

        fixed_layout.append([layout_type, alignment])

    return fixed_layout
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='3.5.1',
|
||||
version='3.6.0',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='3.5.1',
|
||||
version='3.6.0',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user