cutlass 3.9 update (#2255)

* cutlass 3.9 update

* rebase

* fix out-of-shared-memory failure for blockwise kernels on Blackwell

* doc format

* fix issue 2253

* disable host ref by default

* fix sm120 smem capacity

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-04-24 12:42:40 -07:00
committed by GitHub
parent 8e345c5c5b
commit 331a1f5b3f
143 changed files with 18089 additions and 5935 deletions

View File

@ -568,6 +568,9 @@ class OptionRegistry:
def __init__(self, target_cc: int):
self.registry = {}
if target_cc > 90:
raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to 90.")
gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
# Construct options for each CC

View File

@ -65,7 +65,8 @@ class GemmOperation:
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None):
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None,
ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None):
kinds_3x = {
GemmKind.Universal3x,
@ -73,6 +74,8 @@ class GemmOperation:
GemmKind.BlockScaledUniversal3x,
GemmKind.GroupedUniversal3x,
GemmKind.GroupedBlockScaledUniversal3x,
GemmKind.BlockwiseUniversal3x,
GemmKind.GroupedBlockwiseUniversal3x,
}
self.is_3x = gemm_kind in kinds_3x
self.prefix = "3x" if self.is_3x else ""
@ -91,6 +94,11 @@ class GemmOperation:
self.ScaleFactorD = ScaleFactorD["tensor"]
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
if is_blockwise(gemm_kind):
self.ScaleFactorMVecSize = ScaleFactorMVecSize
self.ScaleFactorNVecSize = ScaleFactorNVecSize
self.ScaleFactorKVecSize = ScaleFactorKVecSize
if self.D == None:
self.D = self.C
@ -191,6 +199,8 @@ class GemmOperation:
# Generates a string representing the MMA instruction.
def extended_name(self):
''' Append data types if they differ from compute type. '''
element_sfa = ""
element_sfb = ""
if self.is_complex():
extended_name = "${core_name}"
else:
@ -198,6 +208,10 @@ class GemmOperation:
extended_name = "${core_name}_${element_a}_${element_b}"
if self.C.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_" + extended_name
elif is_blockwise(self.gemm_kind):
extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}"
element_sfa = DataTypeNames[self.accumulator_type()]
element_sfb = DataTypeNames[self.accumulator_type()]
else:
extended_name = "${core_name}"
if self.C.element != self.tile_description.math_instruction.element_accumulator:
@ -207,7 +221,9 @@ class GemmOperation:
extended_name = SubstituteTemplate(extended_name, {
'element_a': DataTypeNames[self.A.element],
'element_sfa' : element_sfa,
'element_b': DataTypeNames[self.B.element],
'element_sfb' : element_sfb,
'element_c': DataTypeNames[self.C.element],
'core_name': self.core_name()
})
@ -252,6 +268,22 @@ class GemmOperation:
element_d = d_type_names,
core_name = self.core_name())
if is_blockwise(self.gemm_kind):
d_type_names = DataTypeNames[self.D.element]
extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_sfa = DataTypeNames[self.accumulator_type()],
element_a = DataTypeNames[self.A.element],
element_sfb = DataTypeNames[self.accumulator_type()],
element_b = DataTypeNames[self.B.element],
element_acc = DataTypeNames[self.accumulator_type()],
element_c = DataTypeNames[self.C.element],
element_d = d_type_names,
sfvec_m_size = self.ScaleFactorMVecSize,
sfvec_n_size = self.ScaleFactorNVecSize,
sfvec_k_size = self.ScaleFactorKVecSize,
core_name = self.core_name())
if self.mixed_input_mode != None:
extended_name = extended_name + self.mixed_input_mode_name()
return extended_name
@ -761,6 +793,7 @@ class EmitGemmUniversal3xInstance:
"cutlass/gemm/kernel/gemm_universal.hpp",
"cutlass/gemm/collective/collective_builder.hpp",
"cutlass/epilogue/collective/collective_builder.hpp",
"cutlass/detail/blockwise_scale_layout.hpp",
]
self.builtin_epilogue_functor_template = \
"""${epilogue_functor}<
@ -786,6 +819,7 @@ using ${operation_name}_epilogue =
>::CollectiveOp;
${mixed_dtype_prepare_code}
${blockwise_prepare_code}
using ${operation_name}_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
@ -853,6 +887,18 @@ ${compile_guard_end}
def pointerize_if_grouped(operation, layout):
    """Return *layout* with a pointer suffix when the kind is grouped.

    Grouped (ptr-array) kernels receive one layout pointer per group,
    so the emitted C++ layout type gains a trailing ``* ``.
    """
    if is_grouped(operation.gemm_kind):
        return layout + "* "
    return layout
@staticmethod
def transform_layout_A_if_blockwise(operation, layout):
    """Pair layout A with its scale-factor layout (SFA) for blockwise kinds.

    Non-blockwise kinds pass *layout* through untouched.  For blockwise
    kernels the mainloop expects ``cute::tuple<LayoutA, LayoutSFA>``; the
    SFA type is pointerized for grouped (ptr-array) kernels.
    """
    if not is_blockwise(operation.gemm_kind):
        return layout
    sfa_layout = f"{operation.procedural_name()}_LayoutSFA"
    if is_grouped(operation.gemm_kind):
        sfa_layout += "* "
    return f"cute::tuple<{layout}, {sfa_layout}>"
@staticmethod
def transform_layout_B_if_blockwise(operation, layout):
    """Pair layout B with its scale-factor layout (SFB) for blockwise kinds.

    Mirrors ``transform_layout_A_if_blockwise``: non-blockwise kinds get
    *layout* unchanged; blockwise kernels get ``cute::tuple<LayoutB,
    LayoutSFB>`` with the SFB type pointerized for grouped kernels.
    """
    if not is_blockwise(operation.gemm_kind):
        return layout
    sfb_layout = f"{operation.procedural_name()}_LayoutSFB"
    if is_grouped(operation.gemm_kind):
        sfb_layout += "* "
    return f"cute::tuple<{layout}, {sfb_layout}>"
@staticmethod
def problem_shape(operation):
gemm_shape_type = "cute::Shape<int,int,int,int>"
@ -1017,14 +1063,25 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
else:
element_b = narrow_element
blockwise_prepare_code = ""
if is_blockwise(operation.gemm_kind):
sfm_vec_size = operation.ScaleFactorMVecSize
sfn_vec_size = operation.ScaleFactorNVecSize
sfk_vec_size = operation.ScaleFactorKVecSize
blockwise_prepare_code = f"""
using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>;
using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA());
using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB());
"""
values = {
'operation_name': operation_name_str,
'operation_suffix': self.operation_suffix,
'problem_shape': self.problem_shape(operation),
'element_a': element_a,
'layout_a': self.pointerize_if_grouped(operation, layout_a_str),
'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)),
'element_b': element_b,
'layout_b': self.pointerize_if_grouped(operation, layout_b_str),
'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)),
'element_c': DataTypeTag[operation.C.element],
'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]),
'element_d': DataTypeTag[operation.D.element],
@ -1057,7 +1114,8 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
'mixed_dtype_prepare_code': mixed_dtype_prepare_code
'mixed_dtype_prepare_code': mixed_dtype_prepare_code,
'blockwise_prepare_code' : blockwise_prepare_code
}
return SubstituteTemplate(self.gemm_template, values)
@ -1387,6 +1445,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.Grouped: EmitGemmGroupedInstance,
GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance,
}
self.gemm_kind_wrappers = {
@ -1401,6 +1461,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.Grouped: 'GemmGroupedOperation',
GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation',
GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation',
}
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
@ -1460,6 +1522,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
("grouped_gemm_operation_3x.hpp", None),
("sparse_gemm_operation_3x.hpp", None),
("block_scaled_gemm_operation_3x.hpp", None),
("blockwise_gemm_operation_3x.hpp", None),
("cutlass/arch/wmma.h", None),
("cutlass/numeric_types.h", None)
])

View File

@ -219,6 +219,15 @@ def CreateGemmUniversal3xOperator(
gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]),
"vector_size" : data_type["sfd_type"]["vector_size"]}
assert is_block_scaled(gemm_kind)
if tile_description.explicit_vector_sizes != None:
assert len(tile_description.explicit_vector_sizes) == 3
gemm_op_extra_args["ScaleFactorMVecSize"] = tile_description.explicit_vector_sizes[0]
gemm_op_extra_args["ScaleFactorNVecSize"] = tile_description.explicit_vector_sizes[1]
gemm_op_extra_args["ScaleFactorKVecSize"] = tile_description.explicit_vector_sizes[2]
assert is_blockwise(gemm_kind)
else:
assert not is_blockwise(gemm_kind)
A_dtype = data_type["a_type"]
B_dtype = data_type["b_type"]
@ -5811,6 +5820,87 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
stream_k_schedules,
tile_schedulers=[TileSchedulerType.StreamK])
def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0):
return
instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992)
is_aligned = True
# layouts for ABC and their alignments
layouts = [
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout
]
math_instructions = generate_fp8_math_instructions_sm90(instantiation_level)
tile_descriptions_ = generate_tile_descriptions_sm90(
math_instructions=math_instructions,
is_aligned=is_aligned,
level=instantiation_level)
tile_descriptions = list()
for desc in tile_descriptions_:
desc.explicit_vector_sizes = [1, desc.tile_shape[1], desc.tile_shape[2]]
tile_descriptions.append(copy.deepcopy(desc))
desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]]
tile_descriptions.append(copy.deepcopy(desc))
desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]]
tile_descriptions.append(copy.deepcopy(desc))
desc.explicit_vector_sizes = [1, 1, desc.tile_shape[2]]
tile_descriptions.append(copy.deepcopy(desc))
for tile_desc in tile_descriptions:
math_inst = tile_desc.math_instruction
data_types = []
fp8_types = [DataType.e4m3, DataType.e5m2]
valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2]
valid_types_for_c = copy.deepcopy(valid_types_for_d)
valid_types_for_c.append(DataType.void)
for c_type, d_type in product(valid_types_for_c, valid_types_for_d):
data_types.append(
generate_data_types_from_math_instruction(
math_inst,
element_source=c_type,
element_dest=d_type,
)
)
else:
for d_type in valid_types_for_d:
data_types.append(
generate_data_types_from_math_instruction(
math_inst,
element_source=DataType.void,
element_dest=d_type,
)
)
for layout in layouts:
for data_type in data_types:
# Inconsistency: alignments aren't fixed in FP8
# layout = fix_alignments(data_type, layout, alignment_bits=128)
schedules, stream_k_schedules = get_valid_schedules(
tile_description=tile_desc,
cuda_version=cuda_version,
is_aligned=is_aligned,
data_types=data_type,
instantiation_level=instantiation_level,
layout=layout,
gemm_kind=gemm_kind,
enable_fp8_fast_acc=False,
)
if len(schedules):
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind)
if len(stream_k_schedules):
assert CudaToolkitVersionSatisfies(cuda_version, 12, 1)
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type,
stream_k_schedules,
tile_schedulers=[TileSchedulerType.StreamK],
gemm_kind=gemm_kind)
def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
@ -7499,6 +7589,245 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x):
    """Emit SM100 (Blackwell) blockwise-scaled FP8 UMMA GEMM kernels.

    Generates 1-SM MMA kernels over all row/column-major A/B/C layout
    combinations, several instruction shapes, three scale-factor vector
    granularities per tile, and f16/bf16/f32 (plus void-C) epilogues.
    ``gemm_kind`` selects plain vs. grouped (ptr-array) kernel emission.
    """
    # Blockwise SM100 kernels require CUDA 12.8+.
    if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
        return

    grouped = is_grouped(gemm_kind)

    # Layouts for A/B/C and their alignments.  The C/D alignment (0) is a
    # placeholder patched below once the destination data type is known.
    layouts = [
        [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
        [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
        [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
        [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
        [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
        [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
        [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
        [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
    ]

    min_cc = 100
    max_cc = 100
    epi_type = DataType.f32

    # (element_a, element_b) pairs emitted per instruction shape.
    # DataType.f8 denotes a runtime-selected FP8 type.
    fp8_ab_pairs = [
        (DataType.f8, DataType.f8),
        (DataType.e4m3, DataType.e4m3),
        (DataType.e4m3, DataType.e5m2),
        (DataType.e5m2, DataType.e4m3),
    ]
    inst_shapes_and_ab = [
        ([64, 128, 32], fp8_ab_pairs),
        ([128, 32, 32], fp8_ab_pairs),
        ([128, 64, 32], fp8_ab_pairs),
        ([128, 128, 32], fp8_ab_pairs),
        # The 128x256 instruction omits the runtime-FP8 pair.
        ([128, 256, 32], fp8_ab_pairs[1:]),
    ]
    math_instructions_1sm = [
        MathInstruction(shape, a_type, b_type, DataType.f32,
                        OpcodeClass.TensorOp, MathOperation.multiply_add)
        for shape, ab_pairs in inst_shapes_and_ab
        for a_type, b_type in ab_pairs
    ]

    cluster_shapes_1sm = [[1, 2, 1], [2, 1, 1], [1, 1, 1], [1, 4, 1], [4, 4, 1],
                         DynamicClusterShape]

    tile_schedulers = [
        TileSchedulerType.Default,
    ]

    # 1xSM MMA kernels
    for math_inst in math_instructions_1sm:
        inst_m, inst_n, inst_k = math_inst.instruction_shape
        tile_descriptions = []
        for cluster_shape in cluster_shapes_1sm:
            multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
            # Three scale-factor granularities per tile shape: full MxN
            # blocks, per-column (1xN) blocks, and per-row (Mx1) blocks.
            for explicit_vector_sizes in ([inst_m, inst_n, inst_k * 4],
                                          [1, inst_n, inst_k * 4],
                                          [inst_m, 1, inst_k * 4]):
                tile_descriptions.append(
                    TileDescription([inst_m * multiplier_1sm[0],
                                     inst_n * multiplier_1sm[1],
                                     inst_k * 4 * multiplier_1sm[2]],
                                    0, [4, 1, 1], math_inst, min_cc, max_cc,
                                    cluster_shape, explicit_vector_sizes))

        # C/D combinations: f16/bf16/f32 destinations, each with and without
        # a source tensor (void-C).
        cd_pairs = [
            (DataType.f16, DataType.f16),
            (DataType.bf16, DataType.bf16),
            (DataType.f32, DataType.f32),
            (DataType.void, DataType.f16),
            (DataType.void, DataType.bf16),
            (DataType.void, DataType.f32),
        ]
        data_types = [
            {
                "a_type": math_inst.element_a,
                "b_type": math_inst.element_b,
                "c_type": c_type,
                "d_type": d_type,
                "acc_type": math_inst.element_accumulator,
                "epi_type": epi_type,
            }
            for c_type, d_type in cd_pairs
        ]

        # Set alignment d based on Destination format.
        for layout in layouts:
            layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]

        # f4/f6/f8 are placeholders whose concrete type is chosen at runtime.
        is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)

        for data_type in data_types:
            # Skip the unsupported e4m3 x e4m3 -> e5m2 destination combination.
            if (data_type["a_type"] == DataType.e4m3 and
                    data_type["b_type"] == DataType.e4m3 and
                    data_type["d_type"] == DataType.e5m2):
                continue

            is_runtime_datatype_a = is_runtime_datatype(data_type["a_type"])
            # BUGFIX: this previously tested data_type["d_type"], which is never
            # a runtime type here, so every runtime-FP8 (DataType.f8)
            # MathInstruction built above was silently filtered out.  The check
            # must compare A against B, per the comment below.
            is_runtime_datatype_b = is_runtime_datatype(data_type["b_type"])

            # A/B datatypes should be both static or dynamic
            if is_runtime_datatype_a != is_runtime_datatype_b:
                continue

            # grouped GEMM does not support runtime data type yet
            if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
                continue

            kernel_schedule = to_grouped_schedule(
                KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
            epi_schedule = to_grouped_schedule(
                EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
            CreateGemmUniversal3xOperator(
                manifest, layouts, tile_descriptions, data_type,
                [[kernel_schedule, epi_schedule]],
                tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
# SM100 MMA with mixed F4/F6/F8 inputs + without block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
@ -10318,6 +10647,11 @@ def GenerateSM100(manifest, cuda_version):
# StreamK is included in regular generation
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
# Blockwise kernels
GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version)
GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x)
#
# Sparse Gemm
#
@ -10755,6 +11089,8 @@ def GenerateSM90(manifest, cuda_version):
GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version)
GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version)
GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version)
GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x)
###################################################################################################
@ -10819,6 +11155,8 @@ if __name__ == "__main__":
manifest = Manifest(args)
archs = args.architectures.split(';')
GenerateSM50(manifest, args.cuda_version)
GenerateSM60(manifest, args.cuda_version)
GenerateSM61(manifest, args.cuda_version)
@ -10827,8 +11165,8 @@ if __name__ == "__main__":
GenerateSM80(manifest, args.cuda_version)
GenerateSM89(manifest, args.cuda_version)
GenerateSM90(manifest, args.cuda_version)
blackwell_enabled_arch = args.architectures in ["100a", "101a", "120a"]
blackwell_enabled_arch = any(arch in ["100a", "101a", "120a"] for arch in archs)
if blackwell_enabled_arch:
GenerateSM100(manifest, args.cuda_version)
GenerateSM120(manifest, args.cuda_version)

View File

@ -324,8 +324,12 @@ def is_complex(data_type):
def is_block_scaled(gemm_kind):
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
def is_blockwise(gemm_kind):
    """Return True for blockwise-scaled 3.x GEMM kinds (plain or grouped)."""
    blockwise_kinds = {
        GemmKind.BlockwiseUniversal3x,
        GemmKind.GroupedBlockwiseUniversal3x,
    }
    return gemm_kind in blockwise_kinds
def is_grouped(gemm_kind):
    """Return True for grouped (ptr-array) 3.x GEMM kinds.

    BUGFIX: the function body contained two consecutive ``return``
    statements (a stale copy without GroupedBlockwiseUniversal3x followed
    by the updated one); the first always executed, so grouped blockwise
    kernels were misreported as non-grouped.  Only the updated membership
    test is kept.
    """
    return gemm_kind in (GemmKind.GroupedUniversal3x,
                         GemmKind.GroupedBlockScaledUniversal3x,
                         GemmKind.GroupedBlockwiseUniversal3x)
#
def get_complex_from_real(real_type):
@ -493,6 +497,9 @@ class KernelScheduleType(enum.Enum):
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
BlockwiseTmaWarpSpecializedCooperative = enum_auto()
PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1SmSm100 = enum_auto()
TmaWarpSpecialized2SmSm100 = enum_auto()
ImplicitTmaWarpSpecialized1SmSm100 = enum_auto()
@ -518,6 +525,13 @@ class KernelScheduleType(enum.Enum):
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
@ -547,6 +561,8 @@ KernelScheduleTag = {
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum',
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
@ -564,6 +580,12 @@ KernelScheduleTag = {
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100',
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100',
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100',
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
@ -574,6 +596,8 @@ KernelScheduleTag = {
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100",
@ -608,6 +632,8 @@ KernelScheduleSuffixes = {
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm',
@ -626,6 +652,11 @@ KernelScheduleSuffixes = {
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
@ -636,6 +667,8 @@ KernelScheduleSuffixes = {
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
@ -730,6 +763,7 @@ def to_grouped_schedule(schedule, grouped):
group_schedule_map = {
# SM90
KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
@ -745,6 +779,9 @@ def to_grouped_schedule(schedule, grouped):
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
}
return group_schedule_map[schedule]
@ -932,6 +969,8 @@ class GemmKind(enum.Enum):
BlockScaledUniversal3x = enum_auto()
GroupedUniversal3x = enum_auto()
GroupedBlockScaledUniversal3x = enum_auto()
BlockwiseUniversal3x = enum_auto()
GroupedBlockwiseUniversal3x = enum_auto()
#
GemmKindNames = {
@ -945,7 +984,9 @@ GemmKindNames = {
GemmKind.Grouped: "gemm_grouped",
GemmKind.BlockScaledUniversal3x: "gemm",
GemmKind.GroupedUniversal3x: "gemm_grouped",
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped"
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
GemmKind.BlockwiseUniversal3x: "gemm",
GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped"
}
#
@ -1149,7 +1190,7 @@ class MathInstruction:
#
class TileDescription:
def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1]):
def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1], explicit_vector_sizes = None):
self.threadblock_shape = threadblock_shape
self.tile_shape = threadblock_shape
self.stages = stages
@ -1158,6 +1199,7 @@ class TileDescription:
self.minimum_compute_capability = min_compute
self.maximum_compute_capability = max_compute
self.cluster_shape = cluster_shape
self.explicit_vector_sizes = explicit_vector_sizes
def procedural_name(self):
if self.minimum_compute_capability >= 90:

View File

@ -511,16 +511,23 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
return [], []
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
schedules = []
schedules.append(
[
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
schedules.append(
[
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
if is_blockwise(gemm_kind):
schedules.append(
[
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
else:
schedules.append(
[
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
schedules.append(
[
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
return schedules, []
return [], []
@ -547,26 +554,34 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
if a_type_size > b_type_size:
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
schedules.append([
KernelScheduleType.TmaWarpSpecialized,
epilogue_schedule
])
schedules.append([
KernelScheduleType.TmaWarpSpecializedPingpong,
epilogue_schedule
])
if not is_blockwise(gemm_kind):
schedules.append([
KernelScheduleType.TmaWarpSpecialized,
epilogue_schedule
])
schedules.append([
KernelScheduleType.TmaWarpSpecializedPingpong,
epilogue_schedule
])
if cta_m >= 128:
if a_type_size > b_type_size:
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
else:
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
epilogue_schedule
])
if is_blockwise(gemm_kind):
schedules.append([
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
epilogue_schedule
])
else:
schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
epilogue_schedule
])
return schedules, []
if not is_aligned:
if not is_aligned and not is_blockwise(gemm_kind):
schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
default_epilogue]]
stream_k_schedules = []
@ -585,7 +600,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
schedules = []
# Pruning: emit Void-C and Grouped kernels with persistent kernels only
if (level >= 1 or not is_void_c) and not grouped:
if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind):
# Pruning: don't stamp out fp8 kernels with auto schedule
if not is_fp8:
schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
@ -596,7 +611,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
if can_do_tma_epilogue:
assert not requires_transposed_epilogue
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
if not is_fp8 or level >= 1:
if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind):
schedules.append([
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
@ -618,14 +633,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
if can_do_cooperative:
schedules.append([
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(default_epilogue, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
default_epilogue
])
if is_blockwise(gemm_kind):
schedules.append([
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(default_epilogue, grouped)
])
stream_k_schedules.append([
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
default_epilogue
])
else:
schedules.append([
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(default_epilogue, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
default_epilogue
])
if can_do_fp8_fast_accum:
schedules.append([
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
@ -640,14 +665,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
if can_do_tma_epilogue:
assert not requires_transposed_epilogue
if can_do_cooperative:
schedules.append([
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecializedCooperative
])
if is_blockwise(gemm_kind):
schedules.append([
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
stream_k_schedules.append([
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecializedCooperative
])
else:
schedules.append([
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecializedCooperative
])
if can_do_fp8_fast_accum:
schedules.append([
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),