CUTLASS 3.8 Release (#2059)

* CUTLASS 3.8 Release

* update

* Update README.md

* Revert "Update README.md"

This reverts commit b353e36fe8.

* update

* update

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
mihir-awatramani
2025-01-24 23:44:06 -08:00
committed by GitHub
parent 9eb01fa0b0
commit 389e493055
290 changed files with 91223 additions and 292 deletions

View File

@ -134,7 +134,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '3.7.0'
this.__version__ = '3.8.0'
from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch

View File

@ -65,11 +65,15 @@ class GemmOperation:
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
):
kinds_3x = {
GemmKind.Universal3x,
GemmKind.SparseUniversal3x,
GemmKind.BlockScaledUniversal3x,
}
self.is_3x = gemm_kind in kinds_3x
self.prefix = "3x" if self.is_3x else ""
@ -82,6 +86,14 @@ class GemmOperation:
self.C = C
self.D = D
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
self.ScaleFactorA = ScaleFactorA
self.ScaleFactorB = ScaleFactorB
self.ScaleFactorD = ScaleFactorD["tensor"]
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
if self.D == None:
self.D = self.C
@ -150,6 +162,7 @@ class GemmOperation:
OpcodeClass.TensorOp,
OpcodeClass.WmmaTensorOp,
OpcodeClass.SparseTensorOp,
OpcodeClass.BlockScaledTensorOp,
]
is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops
@ -207,6 +220,23 @@ class GemmOperation:
element_c = DataTypeNames[self.C.element],
element_d = DataTypeNames[self.D.element],
core_name = self.core_name())
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
d_type_names = DataTypeNames[self.D.element]
if self.ScaleFactorD.element != DataType.void:
d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_sfa = DataTypeNames[self.ScaleFactorA],
element_a = DataTypeNames[self.A.element],
element_sfb = DataTypeNames[self.ScaleFactorB],
element_b = DataTypeNames[self.B.element],
element_acc = DataTypeNames[self.accumulator_type()],
element_c = DataTypeNames[self.C.element],
element_d = d_type_names,
core_name = self.core_name())
return extended_name
def datatype_name_3x(self):
@ -247,6 +277,11 @@ class GemmOperation:
# Generates a short string representing underlying epilogue schedule type
def epilogue_schedule_name_3x(self):
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
if self.ScaleFactorD.element != DataType.void:
return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout]
return EpilogueScheduleSuffixes[self.epilogue_schedule]
# Generate a short string representing the operation class
@ -769,6 +804,32 @@ ${compile_guard_start}
${compile_guard_end}
"""
def emit_block_scale_epilogue_functor(self, operation):
block_scaled_template = """
${epilogue_functor}<
${epi_vs},
${element_d},
${element_accumulator},
${element_sfd},
${layout_sfd},
${element_c},
${element_scalar}
>
"""
block_scaled_values = {
'epi_vs' : str(operation.ScaleFactorVectorSize),
'element_d': str(DataTypeTag[operation.D.element]),
'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]),
'layout_sfd': LayoutTag[operation.ScaleFactorD.layout],
'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor],
'element_accumulator': str(DataTypeTag[operation.accumulator_type()]),
'element_scalar': str(DataTypeTag[operation.accumulator_type()]),
'element_c': str(DataTypeTag[operation.C.element]),
}
return SubstituteTemplate(block_scaled_template, block_scaled_values)
#
def emit(self, operation):
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
@ -778,6 +839,12 @@ ${compile_guard_end}
opcode_class_main = operation.tile_description.math_instruction.opcode_class
opcode_class_epi = opcode_class_main
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
opcode_class_epi = OpcodeClass.TensorOp
tile_shape = operation.tile_description.tile_shape
instruction_shape = operation.tile_description.math_instruction.instruction_shape
cluster_m = operation.tile_description.cluster_shape[0]
@ -790,6 +857,23 @@ ${compile_guard_end}
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
# Shape passed to epilogue builder
is_sm100_kernel = (operation.arch == 100)
if is_sm100_kernel:
cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1
if cluster_m <= 0:
cta_m = cta_m // cta_m_per_mma_instruction
if opcode_class_main in [OpcodeClass.TensorOp
, OpcodeClass.BlockScaledTensorOp
]:
tile_shape_main_m = instruction_shape[0]
tile_shape_main_n = instruction_shape[1]
tile_shape_epi_m = cta_m
tile_shape_epi_n = cta_n
# stage count set to zero indicates builder automatic stage selection
if operation.tile_description.stages > 0:
stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
@ -811,14 +895,37 @@ ${compile_guard_end}
'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
}
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
#
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
values = {
'operation_name': operation.procedural_name(),
'operation_suffix': self.operation_suffix,
@ -1184,6 +1291,7 @@ class EmitGemmConfigurationLibrary:
GemmKind.Universal: EmitGemmUniversalInstance,
GemmKind.Universal3x: EmitGemmUniversal3xInstance,
GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
GemmKind.Grouped: EmitGemmGroupedInstance
@ -1195,6 +1303,7 @@ class EmitGemmConfigurationLibrary:
GemmKind.Universal: 'GemmUniversalOperation',
GemmKind.Universal3x: 'GemmUniversal3xOperation',
GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation',
GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation',
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
GemmKind.Grouped: 'GemmGroupedOperation'
@ -1255,6 +1364,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
("gemm_operation.h", None),
("gemm_operation_3x.hpp", None),
("sparse_gemm_operation_3x.hpp", None),
("block_scaled_gemm_operation_3x.hpp", None),
("cutlass/arch/wmma.h", None),
("cutlass/numeric_types.h", None)
])

File diff suppressed because it is too large Load Diff

View File

@ -83,6 +83,14 @@ class DataType(enum.Enum):
s64 = enum_auto()
e4m3 = enum_auto()
e5m2 = enum_auto()
f8 = enum_auto()
f6 = enum_auto()
f4 = enum_auto()
e3m2 = enum_auto()
e2m3 = enum_auto()
e2m1 = enum_auto()
ue8m0 = enum_auto()
ue4m3 = enum_auto()
f16 = enum_auto()
bf16 = enum_auto()
f32 = enum_auto()
@ -117,6 +125,9 @@ ShortDataTypeNames = {
DataType.f64: 'd',
DataType.cf32: 'c',
DataType.cf64: 'z',
DataType.f8: 'f8',
DataType.f6: 'f6',
DataType.f4: 'f4',
}
#
@ -137,6 +148,14 @@ DataTypeNames = {
DataType.s64: "s64",
DataType.e4m3: 'e4m3',
DataType.e5m2: 'e5m2',
DataType.f8: 'f8',
DataType.f6: 'f6',
DataType.f4: 'f4',
DataType.e2m3: 'e2m3',
DataType.e3m2: 'e3m2',
DataType.e2m1: 'e2m1',
DataType.ue8m0: 'ue8m0',
DataType.ue4m3: 'ue4m3',
DataType.f16: "f16",
DataType.bf16: "bf16",
DataType.f32: "f32",
@ -178,6 +197,14 @@ DataTypeTag = {
DataType.s64: "int64_t",
DataType.e4m3: 'cutlass::float_e4m3_t',
DataType.e5m2: 'cutlass::float_e5m2_t',
DataType.f8: 'cutlass::type_erased_dynamic_float8_t',
DataType.f6: 'cutlass::type_erased_dynamic_float6_t',
DataType.f4: 'cutlass::type_erased_dynamic_float4_t',
DataType.e2m3: 'cutlass::float_e2m3_t',
DataType.e3m2: 'cutlass::float_e3m2_t',
DataType.e2m1: 'cutlass::float_e2m1_t',
DataType.ue8m0: 'cutlass::float_ue8m0_t',
DataType.ue4m3: 'cutlass::float_ue4m3_t',
DataType.f16: "cutlass::half_t",
DataType.bf16: "cutlass::bfloat16_t",
DataType.f32: "float",
@ -219,6 +246,14 @@ DataTypeSize = {
DataType.s64: 64,
DataType.e4m3: 8,
DataType.e5m2: 8,
DataType.f8: 8,
DataType.f6: 6,
DataType.f4: 4,
DataType.e2m3: 6,
DataType.e3m2: 6,
DataType.e2m1: 4,
DataType.ue8m0: 8,
DataType.ue4m3: 8,
DataType.f16: 16,
DataType.bf16: 16,
DataType.f32: 32,
@ -447,6 +482,22 @@ class KernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
ImplicitTmaWarpSpecializedSm90 = enum_auto()
TmaWarpSpecialized1SmSm100 = enum_auto()
TmaWarpSpecialized2SmSm100 = enum_auto()
BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
#
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
@ -462,6 +513,22 @@ KernelScheduleTag = {
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100',
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
}
#
@ -479,6 +546,22 @@ KernelScheduleSuffixes = {
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
}
class EpilogueScheduleType(enum.Enum):
@ -487,6 +570,9 @@ class EpilogueScheduleType(enum.Enum):
NoSmemWarpSpecialized = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
TmaWarpSpecialized2Sm = enum_auto()
#
EpilogueScheduleTag = {
EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
@ -494,6 +580,8 @@ EpilogueScheduleTag = {
EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm',
}
#
@ -503,13 +591,18 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
}
class EpilogueFunctor3x(enum.Enum):
LinearCombination = enum_auto()
LinearCombinationBlockScaleFactor = enum_auto()
#
EpilogueFunctor3xTag = {
EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
}
class TileSchedulerType(enum.Enum):
@ -595,12 +688,15 @@ class OpcodeClass(enum.Enum):
TensorOp = enum_auto()
WmmaTensorOp = enum_auto()
SparseTensorOp = enum_auto()
BlockScaledTensorOp = enum_auto()
OpcodeClassNames = {
OpcodeClass.Simt: 'simt',
OpcodeClass.TensorOp: 'tensorop',
OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
OpcodeClass.SparseTensorOp: 'sptensorop',
OpcodeClass.BlockScaledTensorOp: 'bstensorop'
}
OpcodeClassTag = {
@ -608,6 +704,7 @@ OpcodeClassTag = {
OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp'
}
###################################################################################################
@ -688,6 +785,8 @@ class GemmKind(enum.Enum):
PlanarComplex = enum_auto()
PlanarComplexArray = enum_auto()
Grouped = enum_auto()
BlockScaledUniversal3x = enum_auto()
#
GemmKindNames = {
GemmKind.Gemm: "gemm",
@ -698,6 +797,7 @@ GemmKindNames = {
GemmKind.PlanarComplex: "gemm_planar_complex",
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
GemmKind.Grouped: "gemm_grouped",
GemmKind.BlockScaledUniversal3x: "gemm_block_scaled"
}
#
@ -871,6 +971,8 @@ GroupModeNames = {
GroupMode.Depthwise: 'depthwise',
}
DynamicClusterShape = [0, 0, 1]
###################################################################################################
#
@ -879,6 +981,7 @@ class MathInstruction:
instruction_shape, \
element_a, element_b, element_accumulator, \
opcode_class, math_operation = MathOperation.multiply_add \
, element_scale_factor = None
):
self.instruction_shape = instruction_shape
@ -887,6 +990,8 @@ class MathInstruction:
self.element_accumulator = element_accumulator
self.opcode_class = opcode_class
self.math_operation = math_operation
self.element_scale_factor = element_scale_factor
#
class TileDescription:

View File

@ -522,6 +522,7 @@ class Manifest:
arch_conditional_cc = [
'90a',
'100a'
]
architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures]

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='3.7.0',
version='3.8.0',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='3.7.0',
version='3.8.0',
description='Python implementation of CuTe',
packages=['pycute'],
)