CUTLASS 3.8 Release (#2059)
* CUTLASS 3.8 Release
* update
* Update README.md
* Revert "Update README.md"
This reverts commit b353e36fe8.
* update
* update
---------
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -134,7 +134,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
# Package version string for this release; the superseded '3.7.0' assignment
# was a dead store that this line immediately overwrote.
this.__version__ = '3.8.0'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
|
||||
@ -65,11 +65,15 @@ class GemmOperation:
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default
|
||||
|
||||
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
|
||||
|
||||
):
|
||||
|
||||
kinds_3x = {
|
||||
GemmKind.Universal3x,
|
||||
GemmKind.SparseUniversal3x,
|
||||
GemmKind.BlockScaledUniversal3x,
|
||||
}
|
||||
self.is_3x = gemm_kind in kinds_3x
|
||||
self.prefix = "3x" if self.is_3x else ""
|
||||
@ -82,6 +86,14 @@ class GemmOperation:
|
||||
self.C = C
|
||||
self.D = D
|
||||
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
self.ScaleFactorA = ScaleFactorA
|
||||
self.ScaleFactorB = ScaleFactorB
|
||||
self.ScaleFactorD = ScaleFactorD["tensor"]
|
||||
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
|
||||
|
||||
|
||||
if self.D == None:
|
||||
self.D = self.C
|
||||
|
||||
@ -150,6 +162,7 @@ class GemmOperation:
|
||||
OpcodeClass.TensorOp,
|
||||
OpcodeClass.WmmaTensorOp,
|
||||
OpcodeClass.SparseTensorOp,
|
||||
OpcodeClass.BlockScaledTensorOp,
|
||||
]
|
||||
|
||||
is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops
|
||||
@ -207,6 +220,23 @@ class GemmOperation:
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = DataTypeNames[self.D.element],
|
||||
core_name = self.core_name())
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
d_type_names = DataTypeNames[self.D.element]
|
||||
|
||||
if self.ScaleFactorD.element != DataType.void:
|
||||
d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
|
||||
|
||||
extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_sfa = DataTypeNames[self.ScaleFactorA],
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
element_sfb = DataTypeNames[self.ScaleFactorB],
|
||||
element_b = DataTypeNames[self.B.element],
|
||||
element_acc = DataTypeNames[self.accumulator_type()],
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = d_type_names,
|
||||
core_name = self.core_name())
|
||||
|
||||
return extended_name
|
||||
|
||||
def datatype_name_3x(self):
|
||||
@ -247,6 +277,11 @@ class GemmOperation:
|
||||
|
||||
# Generates a short string representing underlying epilogue schedule type
|
||||
def epilogue_schedule_name_3x(self):
    """Return the short name suffix for this operation's epilogue schedule.

    For block-scaled 3.x GEMMs whose ScaleFactorD element is not void, the
    suffix is extended with "_epiVs", the scale-factor vector size, and the
    short layout name of ScaleFactorD. Otherwise only the plain schedule
    suffix is returned.
    """
    # ScaleFactorD is only read when the GEMM kind is block-scaled
    # (the `and` short-circuits otherwise).
    emits_scale_factor_d = (
        self.gemm_kind == GemmKind.BlockScaledUniversal3x
        and self.ScaleFactorD.element != DataType.void
    )

    suffix = EpilogueScheduleSuffixes[self.epilogue_schedule]
    if not emits_scale_factor_d:
        return suffix

    vector_size = str(self.ScaleFactorVectorSize)
    layout_name = ShortLayoutTypeNames[self.ScaleFactorD.layout]
    return suffix + "_epiVs" + vector_size + layout_name
|
||||
|
||||
# Generate a short string representing the operation class
|
||||
@ -769,6 +804,32 @@ ${compile_guard_start}
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
|
||||
# Render the template-argument text for the block-scale-factor epilogue functor.
#
# Fills a `${epilogue_functor}< ... >` template from `operation`: the
# scale-factor vector size (`epi_vs`), the type tags of D, ScaleFactorD and C,
# the layout tag of ScaleFactorD, and the accumulator type tag, which is used
# for both `element_accumulator` and `element_scalar`. The functor name is the
# tag registered for EpilogueFunctor3x.LinearCombinationBlockScaleFactor.
# Returns whatever SubstituteTemplate produces for that template and mapping.
def emit_block_scale_epilogue_functor(self, operation):
|
||||
block_scaled_template = """
|
||||
${epilogue_functor}<
|
||||
${epi_vs},
|
||||
${element_d},
|
||||
${element_accumulator},
|
||||
${element_sfd},
|
||||
${layout_sfd},
|
||||
${element_c},
|
||||
${element_scalar}
|
||||
>
|
||||
"""
|
||||
block_scaled_values = {
|
||||
'epi_vs' : str(operation.ScaleFactorVectorSize),
|
||||
'element_d': str(DataTypeTag[operation.D.element]),
|
||||
'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]),
|
||||
'layout_sfd': LayoutTag[operation.ScaleFactorD.layout],
|
||||
'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor],
|
||||
'element_accumulator': str(DataTypeTag[operation.accumulator_type()]),
|
||||
'element_scalar': str(DataTypeTag[operation.accumulator_type()]),
|
||||
'element_c': str(DataTypeTag[operation.C.element]),
|
||||
}
|
||||
return SubstituteTemplate(block_scaled_template, block_scaled_values)
|
||||
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
|
||||
@ -778,6 +839,12 @@ ${compile_guard_end}
|
||||
|
||||
opcode_class_main = operation.tile_description.math_instruction.opcode_class
|
||||
opcode_class_epi = opcode_class_main
|
||||
|
||||
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
||||
if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
|
||||
opcode_class_epi = OpcodeClass.TensorOp
|
||||
|
||||
|
||||
tile_shape = operation.tile_description.tile_shape
|
||||
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
@ -790,6 +857,23 @@ ${compile_guard_end}
|
||||
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
|
||||
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
|
||||
|
||||
|
||||
# Shape passed to epilogue builder
|
||||
is_sm100_kernel = (operation.arch == 100)
|
||||
if is_sm100_kernel:
|
||||
cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1
|
||||
if cluster_m <= 0:
|
||||
cta_m = cta_m // cta_m_per_mma_instruction
|
||||
|
||||
if opcode_class_main in [OpcodeClass.TensorOp
|
||||
, OpcodeClass.BlockScaledTensorOp
|
||||
]:
|
||||
tile_shape_main_m = instruction_shape[0]
|
||||
tile_shape_main_n = instruction_shape[1]
|
||||
tile_shape_epi_m = cta_m
|
||||
tile_shape_epi_n = cta_n
|
||||
|
||||
|
||||
# stage count set to zero indicates builder automatic stage selection
|
||||
if operation.tile_description.stages > 0:
|
||||
stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
|
||||
@ -811,14 +895,37 @@ ${compile_guard_end}
|
||||
'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
||||
|
||||
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
|
||||
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
||||
|
||||
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
|
||||
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
|
||||
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
||||
|
||||
#
|
||||
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
|
||||
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
||||
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
|
||||
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
|
||||
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
||||
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
@ -1184,6 +1291,7 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Universal: EmitGemmUniversalInstance,
|
||||
GemmKind.Universal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance
|
||||
@ -1195,6 +1303,7 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Universal: 'GemmUniversalOperation',
|
||||
GemmKind.Universal3x: 'GemmUniversal3xOperation',
|
||||
GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation',
|
||||
GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation',
|
||||
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
|
||||
GemmKind.Grouped: 'GemmGroupedOperation'
|
||||
@ -1255,6 +1364,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
("gemm_operation.h", None),
|
||||
("gemm_operation_3x.hpp", None),
|
||||
("sparse_gemm_operation_3x.hpp", None),
|
||||
("block_scaled_gemm_operation_3x.hpp", None),
|
||||
("cutlass/arch/wmma.h", None),
|
||||
("cutlass/numeric_types.h", None)
|
||||
])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -83,6 +83,14 @@ class DataType(enum.Enum):
|
||||
s64 = enum_auto()
|
||||
e4m3 = enum_auto()
|
||||
e5m2 = enum_auto()
|
||||
f8 = enum_auto()
|
||||
f6 = enum_auto()
|
||||
f4 = enum_auto()
|
||||
e3m2 = enum_auto()
|
||||
e2m3 = enum_auto()
|
||||
e2m1 = enum_auto()
|
||||
ue8m0 = enum_auto()
|
||||
ue4m3 = enum_auto()
|
||||
f16 = enum_auto()
|
||||
bf16 = enum_auto()
|
||||
f32 = enum_auto()
|
||||
@ -117,6 +125,9 @@ ShortDataTypeNames = {
|
||||
DataType.f64: 'd',
|
||||
DataType.cf32: 'c',
|
||||
DataType.cf64: 'z',
|
||||
DataType.f8: 'f8',
|
||||
DataType.f6: 'f6',
|
||||
DataType.f4: 'f4',
|
||||
}
|
||||
|
||||
#
|
||||
@ -137,6 +148,14 @@ DataTypeNames = {
|
||||
DataType.s64: "s64",
|
||||
DataType.e4m3: 'e4m3',
|
||||
DataType.e5m2: 'e5m2',
|
||||
DataType.f8: 'f8',
|
||||
DataType.f6: 'f6',
|
||||
DataType.f4: 'f4',
|
||||
DataType.e2m3: 'e2m3',
|
||||
DataType.e3m2: 'e3m2',
|
||||
DataType.e2m1: 'e2m1',
|
||||
DataType.ue8m0: 'ue8m0',
|
||||
DataType.ue4m3: 'ue4m3',
|
||||
DataType.f16: "f16",
|
||||
DataType.bf16: "bf16",
|
||||
DataType.f32: "f32",
|
||||
@ -178,6 +197,14 @@ DataTypeTag = {
|
||||
DataType.s64: "int64_t",
|
||||
DataType.e4m3: 'cutlass::float_e4m3_t',
|
||||
DataType.e5m2: 'cutlass::float_e5m2_t',
|
||||
DataType.f8: 'cutlass::type_erased_dynamic_float8_t',
|
||||
DataType.f6: 'cutlass::type_erased_dynamic_float6_t',
|
||||
DataType.f4: 'cutlass::type_erased_dynamic_float4_t',
|
||||
DataType.e2m3: 'cutlass::float_e2m3_t',
|
||||
DataType.e3m2: 'cutlass::float_e3m2_t',
|
||||
DataType.e2m1: 'cutlass::float_e2m1_t',
|
||||
DataType.ue8m0: 'cutlass::float_ue8m0_t',
|
||||
DataType.ue4m3: 'cutlass::float_ue4m3_t',
|
||||
DataType.f16: "cutlass::half_t",
|
||||
DataType.bf16: "cutlass::bfloat16_t",
|
||||
DataType.f32: "float",
|
||||
@ -219,6 +246,14 @@ DataTypeSize = {
|
||||
DataType.s64: 64,
|
||||
DataType.e4m3: 8,
|
||||
DataType.e5m2: 8,
|
||||
DataType.f8: 8,
|
||||
DataType.f6: 6,
|
||||
DataType.f4: 4,
|
||||
DataType.e2m3: 6,
|
||||
DataType.e3m2: 6,
|
||||
DataType.e2m1: 4,
|
||||
DataType.ue8m0: 8,
|
||||
DataType.ue4m3: 8,
|
||||
DataType.f16: 16,
|
||||
DataType.bf16: 16,
|
||||
DataType.f32: 32,
|
||||
@ -447,6 +482,22 @@ class KernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
ImplicitTmaWarpSpecializedSm90 = enum_auto()
|
||||
|
||||
TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
|
||||
BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
|
||||
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
#
|
||||
KernelScheduleTag = {
|
||||
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
|
||||
@ -462,6 +513,22 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
|
||||
|
||||
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
|
||||
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
|
||||
|
||||
|
||||
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100',
|
||||
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
|
||||
|
||||
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
||||
|
||||
}
|
||||
|
||||
#
|
||||
@ -479,6 +546,22 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
|
||||
|
||||
KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm',
|
||||
KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm',
|
||||
|
||||
|
||||
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm',
|
||||
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
|
||||
|
||||
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
}
|
||||
|
||||
class EpilogueScheduleType(enum.Enum):
|
||||
@ -487,6 +570,9 @@ class EpilogueScheduleType(enum.Enum):
|
||||
NoSmemWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
TmaWarpSpecialized1Sm = enum_auto()
|
||||
TmaWarpSpecialized2Sm = enum_auto()
|
||||
|
||||
#
|
||||
EpilogueScheduleTag = {
|
||||
EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
|
||||
@ -494,6 +580,8 @@ EpilogueScheduleTag = {
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm',
|
||||
}
|
||||
|
||||
#
|
||||
@ -503,13 +591,18 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
|
||||
}
|
||||
|
||||
class EpilogueFunctor3x(enum.Enum):
    """Epilogue functor kinds that have an entry in EpilogueFunctor3xTag."""

    # Plain linear-combination epilogue.
    LinearCombination = enum_auto()
    # Linear combination paired with a block scale factor.
    LinearCombinationBlockScaleFactor = enum_auto()
|
||||
|
||||
#
|
||||
# Maps each EpilogueFunctor3x member to the fully qualified C++ name that is
# substituted into emitted epilogue declarations.
EpilogueFunctor3xTag = {
    EpilogueFunctor3x.LinearCombination:
        'cutlass::epilogue::fusion::LinearCombination',
    EpilogueFunctor3x.LinearCombinationBlockScaleFactor:
        'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
}
|
||||
|
||||
class TileSchedulerType(enum.Enum):
|
||||
@ -595,12 +688,15 @@ class OpcodeClass(enum.Enum):
|
||||
TensorOp = enum_auto()
|
||||
WmmaTensorOp = enum_auto()
|
||||
SparseTensorOp = enum_auto()
|
||||
BlockScaledTensorOp = enum_auto()
|
||||
|
||||
|
||||
# Short lowercase name for each OpcodeClass member.
OpcodeClassNames = {
    OpcodeClass.Simt: 'simt',
    OpcodeClass.TensorOp: 'tensorop',
    OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
    OpcodeClass.SparseTensorOp: 'sptensorop',
    OpcodeClass.BlockScaledTensorOp: 'bstensorop',
}
|
||||
|
||||
OpcodeClassTag = {
|
||||
@ -608,6 +704,7 @@ OpcodeClassTag = {
|
||||
OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
|
||||
OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
|
||||
OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
|
||||
OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp'
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
@ -688,6 +785,8 @@ class GemmKind(enum.Enum):
|
||||
PlanarComplex = enum_auto()
|
||||
PlanarComplexArray = enum_auto()
|
||||
Grouped = enum_auto()
|
||||
BlockScaledUniversal3x = enum_auto()
|
||||
|
||||
#
|
||||
GemmKindNames = {
|
||||
GemmKind.Gemm: "gemm",
|
||||
@ -698,6 +797,7 @@ GemmKindNames = {
|
||||
GemmKind.PlanarComplex: "gemm_planar_complex",
|
||||
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
|
||||
GemmKind.Grouped: "gemm_grouped",
|
||||
GemmKind.BlockScaledUniversal3x: "gemm_block_scaled"
|
||||
}
|
||||
|
||||
#
|
||||
@ -871,6 +971,8 @@ GroupModeNames = {
|
||||
GroupMode.Depthwise: 'depthwise',
|
||||
}
|
||||
|
||||
# Cluster shape whose first two extents are 0.
# NOTE(review): presumably a sentinel meaning the cluster shape is chosen at
# launch time rather than fixed in the generated kernel — confirm against the
# generator code that consumes it.
DynamicClusterShape = [0, 0, 1]
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
@ -879,6 +981,7 @@ class MathInstruction:
|
||||
instruction_shape, \
|
||||
element_a, element_b, element_accumulator, \
|
||||
opcode_class, math_operation = MathOperation.multiply_add \
|
||||
, element_scale_factor = None
|
||||
):
|
||||
|
||||
self.instruction_shape = instruction_shape
|
||||
@ -887,6 +990,8 @@ class MathInstruction:
|
||||
self.element_accumulator = element_accumulator
|
||||
self.opcode_class = opcode_class
|
||||
self.math_operation = math_operation
|
||||
self.element_scale_factor = element_scale_factor
|
||||
|
||||
#
|
||||
class TileDescription:
|
||||
|
||||
|
||||
@ -522,6 +522,7 @@ class Manifest:
|
||||
|
||||
arch_conditional_cc = [
|
||||
'90a',
|
||||
'100a'
|
||||
]
|
||||
architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures]
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
    """Invoke setuptools ``setup`` for the ``cutlass_library`` package.

    ``version`` is passed exactly once, with the 3.8.0 release value; passing
    the keyword twice is a SyntaxError (repeated keyword argument).
    """
    setup(
        name='cutlass_library',
        version='3.8.0',
        description='CUTLASS library generation scripts',
        packages=['cutlass_library']
    )
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
    """Invoke setuptools ``setup`` for the ``pycute`` package.

    ``version`` is passed exactly once, with the 3.8.0 release value; passing
    the keyword twice is a SyntaxError (repeated keyword argument).
    """
    setup(
        name='pycute',
        version='3.8.0',
        description='Python implementation of CuTe',
        packages=['pycute'],
    )
|
||||
|
||||
Reference in New Issue
Block a user