cutlass 3.9 update (#2255)
* cutlass 3.9 update * rebase * fixes out of shared memory for blockwise Blackwell * doc format * fix issue 2253 * disable host ref by default * fix sm120 smem capacity --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -568,6 +568,9 @@ class OptionRegistry:
|
||||
def __init__(self, target_cc: int):
|
||||
self.registry = {}
|
||||
|
||||
if target_cc > 90:
|
||||
raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to 90.")
|
||||
|
||||
gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
|
||||
operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
|
||||
# Construct options for each CC
|
||||
|
||||
@ -65,7 +65,8 @@ class GemmOperation:
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
|
||||
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None):
|
||||
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None,
|
||||
ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None):
|
||||
|
||||
kinds_3x = {
|
||||
GemmKind.Universal3x,
|
||||
@ -73,6 +74,8 @@ class GemmOperation:
|
||||
GemmKind.BlockScaledUniversal3x,
|
||||
GemmKind.GroupedUniversal3x,
|
||||
GemmKind.GroupedBlockScaledUniversal3x,
|
||||
GemmKind.BlockwiseUniversal3x,
|
||||
GemmKind.GroupedBlockwiseUniversal3x,
|
||||
}
|
||||
self.is_3x = gemm_kind in kinds_3x
|
||||
self.prefix = "3x" if self.is_3x else ""
|
||||
@ -91,6 +94,11 @@ class GemmOperation:
|
||||
self.ScaleFactorD = ScaleFactorD["tensor"]
|
||||
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
|
||||
|
||||
if is_blockwise(gemm_kind):
|
||||
self.ScaleFactorMVecSize = ScaleFactorMVecSize
|
||||
self.ScaleFactorNVecSize = ScaleFactorNVecSize
|
||||
self.ScaleFactorKVecSize = ScaleFactorKVecSize
|
||||
|
||||
if self.D == None:
|
||||
self.D = self.C
|
||||
|
||||
@ -191,6 +199,8 @@ class GemmOperation:
|
||||
# Generates a string representing the MMA instruction.
|
||||
def extended_name(self):
|
||||
''' Append data types if they differ from compute type. '''
|
||||
element_sfa = ""
|
||||
element_sfb = ""
|
||||
if self.is_complex():
|
||||
extended_name = "${core_name}"
|
||||
else:
|
||||
@ -198,6 +208,10 @@ class GemmOperation:
|
||||
extended_name = "${core_name}_${element_a}_${element_b}"
|
||||
if self.C.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${element_c}_" + extended_name
|
||||
elif is_blockwise(self.gemm_kind):
|
||||
extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}"
|
||||
element_sfa = DataTypeNames[self.accumulator_type()]
|
||||
element_sfb = DataTypeNames[self.accumulator_type()]
|
||||
else:
|
||||
extended_name = "${core_name}"
|
||||
if self.C.element != self.tile_description.math_instruction.element_accumulator:
|
||||
@ -207,7 +221,9 @@ class GemmOperation:
|
||||
|
||||
extended_name = SubstituteTemplate(extended_name, {
|
||||
'element_a': DataTypeNames[self.A.element],
|
||||
'element_sfa' : element_sfa,
|
||||
'element_b': DataTypeNames[self.B.element],
|
||||
'element_sfb' : element_sfb,
|
||||
'element_c': DataTypeNames[self.C.element],
|
||||
'core_name': self.core_name()
|
||||
})
|
||||
@ -252,6 +268,22 @@ class GemmOperation:
|
||||
element_d = d_type_names,
|
||||
core_name = self.core_name())
|
||||
|
||||
if is_blockwise(self.gemm_kind):
|
||||
d_type_names = DataTypeNames[self.D.element]
|
||||
|
||||
extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_sfa = DataTypeNames[self.accumulator_type()],
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
element_sfb = DataTypeNames[self.accumulator_type()],
|
||||
element_b = DataTypeNames[self.B.element],
|
||||
element_acc = DataTypeNames[self.accumulator_type()],
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = d_type_names,
|
||||
sfvec_m_size = self.ScaleFactorMVecSize,
|
||||
sfvec_n_size = self.ScaleFactorNVecSize,
|
||||
sfvec_k_size = self.ScaleFactorKVecSize,
|
||||
core_name = self.core_name())
|
||||
|
||||
if self.mixed_input_mode != None:
|
||||
extended_name = extended_name + self.mixed_input_mode_name()
|
||||
return extended_name
|
||||
@ -761,6 +793,7 @@ class EmitGemmUniversal3xInstance:
|
||||
"cutlass/gemm/kernel/gemm_universal.hpp",
|
||||
"cutlass/gemm/collective/collective_builder.hpp",
|
||||
"cutlass/epilogue/collective/collective_builder.hpp",
|
||||
"cutlass/detail/blockwise_scale_layout.hpp",
|
||||
]
|
||||
self.builtin_epilogue_functor_template = \
|
||||
"""${epilogue_functor}<
|
||||
@ -786,6 +819,7 @@ using ${operation_name}_epilogue =
|
||||
>::CollectiveOp;
|
||||
|
||||
${mixed_dtype_prepare_code}
|
||||
${blockwise_prepare_code}
|
||||
|
||||
using ${operation_name}_mainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
@ -853,6 +887,18 @@ ${compile_guard_end}
|
||||
def pointerize_if_grouped(operation, layout):
|
||||
return layout if not is_grouped(operation.gemm_kind) else layout + "* "
|
||||
|
||||
@staticmethod
|
||||
def transform_layout_A_if_blockwise(operation, layout):
|
||||
layout_sfa = f"{operation.procedural_name()}_LayoutSFA"
|
||||
layout_sfa = layout_sfa if not is_grouped(operation.gemm_kind) else layout_sfa + "* "
|
||||
return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfa}>"
|
||||
|
||||
@staticmethod
|
||||
def transform_layout_B_if_blockwise(operation, layout):
|
||||
layout_sfb = f"{operation.procedural_name()}_LayoutSFB"
|
||||
layout_sfb = layout_sfb if not is_grouped(operation.gemm_kind) else layout_sfb + "* "
|
||||
return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfb}>"
|
||||
|
||||
@staticmethod
|
||||
def problem_shape(operation):
|
||||
gemm_shape_type = "cute::Shape<int,int,int,int>"
|
||||
@ -1017,14 +1063,25 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
|
||||
else:
|
||||
element_b = narrow_element
|
||||
|
||||
blockwise_prepare_code = ""
|
||||
if is_blockwise(operation.gemm_kind):
|
||||
sfm_vec_size = operation.ScaleFactorMVecSize
|
||||
sfn_vec_size = operation.ScaleFactorNVecSize
|
||||
sfk_vec_size = operation.ScaleFactorKVecSize
|
||||
blockwise_prepare_code = f"""
|
||||
using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>;
|
||||
using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA());
|
||||
using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB());
|
||||
"""
|
||||
|
||||
values = {
|
||||
'operation_name': operation_name_str,
|
||||
'operation_suffix': self.operation_suffix,
|
||||
'problem_shape': self.problem_shape(operation),
|
||||
'element_a': element_a,
|
||||
'layout_a': self.pointerize_if_grouped(operation, layout_a_str),
|
||||
'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)),
|
||||
'element_b': element_b,
|
||||
'layout_b': self.pointerize_if_grouped(operation, layout_b_str),
|
||||
'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)),
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]),
|
||||
'element_d': DataTypeTag[operation.D.element],
|
||||
@ -1057,7 +1114,8 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
|
||||
'mixed_dtype_prepare_code': mixed_dtype_prepare_code
|
||||
'mixed_dtype_prepare_code': mixed_dtype_prepare_code,
|
||||
'blockwise_prepare_code' : blockwise_prepare_code
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.gemm_template, values)
|
||||
@ -1387,6 +1445,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance,
|
||||
GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance,
|
||||
}
|
||||
|
||||
self.gemm_kind_wrappers = {
|
||||
@ -1401,6 +1461,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Grouped: 'GemmGroupedOperation',
|
||||
GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
|
||||
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
|
||||
GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation',
|
||||
GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation',
|
||||
}
|
||||
|
||||
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
|
||||
@ -1460,6 +1522,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
("grouped_gemm_operation_3x.hpp", None),
|
||||
("sparse_gemm_operation_3x.hpp", None),
|
||||
("block_scaled_gemm_operation_3x.hpp", None),
|
||||
("blockwise_gemm_operation_3x.hpp", None),
|
||||
("cutlass/arch/wmma.h", None),
|
||||
("cutlass/numeric_types.h", None)
|
||||
])
|
||||
|
||||
@ -219,6 +219,15 @@ def CreateGemmUniversal3xOperator(
|
||||
gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]),
|
||||
"vector_size" : data_type["sfd_type"]["vector_size"]}
|
||||
assert is_block_scaled(gemm_kind)
|
||||
|
||||
if tile_description.explicit_vector_sizes != None:
|
||||
assert len(tile_description.explicit_vector_sizes) == 3
|
||||
gemm_op_extra_args["ScaleFactorMVecSize"] = tile_description.explicit_vector_sizes[0]
|
||||
gemm_op_extra_args["ScaleFactorNVecSize"] = tile_description.explicit_vector_sizes[1]
|
||||
gemm_op_extra_args["ScaleFactorKVecSize"] = tile_description.explicit_vector_sizes[2]
|
||||
assert is_blockwise(gemm_kind)
|
||||
else:
|
||||
assert not is_blockwise(gemm_kind)
|
||||
|
||||
A_dtype = data_type["a_type"]
|
||||
B_dtype = data_type["b_type"]
|
||||
@ -5811,6 +5820,87 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
stream_k_schedules,
|
||||
tile_schedulers=[TileSchedulerType.StreamK])
|
||||
|
||||
def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x):
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0):
|
||||
return
|
||||
|
||||
instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992)
|
||||
is_aligned = True
|
||||
|
||||
# layouts for ABC and their alignments
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout
|
||||
]
|
||||
|
||||
math_instructions = generate_fp8_math_instructions_sm90(instantiation_level)
|
||||
tile_descriptions_ = generate_tile_descriptions_sm90(
|
||||
math_instructions=math_instructions,
|
||||
is_aligned=is_aligned,
|
||||
level=instantiation_level)
|
||||
|
||||
tile_descriptions = list()
|
||||
|
||||
for desc in tile_descriptions_:
|
||||
desc.explicit_vector_sizes = [1, desc.tile_shape[1], desc.tile_shape[2]]
|
||||
tile_descriptions.append(copy.deepcopy(desc))
|
||||
desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]]
|
||||
tile_descriptions.append(copy.deepcopy(desc))
|
||||
desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]]
|
||||
tile_descriptions.append(copy.deepcopy(desc))
|
||||
desc.explicit_vector_sizes = [1, 1, desc.tile_shape[2]]
|
||||
tile_descriptions.append(copy.deepcopy(desc))
|
||||
|
||||
for tile_desc in tile_descriptions:
|
||||
math_inst = tile_desc.math_instruction
|
||||
data_types = []
|
||||
fp8_types = [DataType.e4m3, DataType.e5m2]
|
||||
valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2]
|
||||
valid_types_for_c = copy.deepcopy(valid_types_for_d)
|
||||
valid_types_for_c.append(DataType.void)
|
||||
for c_type, d_type in product(valid_types_for_c, valid_types_for_d):
|
||||
data_types.append(
|
||||
generate_data_types_from_math_instruction(
|
||||
math_inst,
|
||||
element_source=c_type,
|
||||
element_dest=d_type,
|
||||
)
|
||||
)
|
||||
else:
|
||||
for d_type in valid_types_for_d:
|
||||
data_types.append(
|
||||
generate_data_types_from_math_instruction(
|
||||
math_inst,
|
||||
element_source=DataType.void,
|
||||
element_dest=d_type,
|
||||
)
|
||||
)
|
||||
|
||||
for layout in layouts:
|
||||
for data_type in data_types:
|
||||
# Inconsistency: alignments aren't fixed in FP8
|
||||
# layout = fix_alignments(data_type, layout, alignment_bits=128)
|
||||
|
||||
schedules, stream_k_schedules = get_valid_schedules(
|
||||
tile_description=tile_desc,
|
||||
cuda_version=cuda_version,
|
||||
is_aligned=is_aligned,
|
||||
data_types=data_type,
|
||||
instantiation_level=instantiation_level,
|
||||
layout=layout,
|
||||
gemm_kind=gemm_kind,
|
||||
enable_fp8_fast_acc=False,
|
||||
)
|
||||
|
||||
if len(schedules):
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind)
|
||||
if len(stream_k_schedules):
|
||||
assert CudaToolkitVersionSatisfies(cuda_version, 12, 1)
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type,
|
||||
stream_k_schedules,
|
||||
tile_schedulers=[TileSchedulerType.StreamK],
|
||||
gemm_kind=gemm_kind)
|
||||
|
||||
|
||||
|
||||
def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version):
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
@ -7499,6 +7589,245 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
|
||||
def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x):
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
min_cc = 100
|
||||
max_cc = 100
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = [
|
||||
# inst 64x128
|
||||
MathInstruction(
|
||||
[64, 128, 32],
|
||||
DataType.f8, DataType.f8, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 128, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 128, 32],
|
||||
DataType.e4m3, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 128, 32],
|
||||
DataType.e5m2, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
# inst 128x32
|
||||
MathInstruction(
|
||||
[128, 32, 32],
|
||||
DataType.f8, DataType.f8, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 32, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 32, 32],
|
||||
DataType.e4m3, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 32, 32],
|
||||
DataType.e5m2, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
# inst 128x64
|
||||
MathInstruction(
|
||||
[128, 64, 32],
|
||||
DataType.f8, DataType.f8, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 64, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 64, 32],
|
||||
DataType.e4m3, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 64, 32],
|
||||
DataType.e5m2, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
# inst 128x128
|
||||
MathInstruction(
|
||||
[128, 128, 32],
|
||||
DataType.f8, DataType.f8, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 128, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 128, 32],
|
||||
DataType.e4m3, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 128, 32],
|
||||
DataType.e5m2, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
# inst 128x256
|
||||
MathInstruction(
|
||||
[128, 256, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 256, 32],
|
||||
DataType.e4m3, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[128, 256, 32],
|
||||
DataType.e5m2, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add)]
|
||||
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_1sm:
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_1sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_1sm[1],
|
||||
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape,
|
||||
[math_inst.instruction_shape[0], math_inst.instruction_shape[1],
|
||||
math_inst.instruction_shape[2] * 4]))
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_1sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_1sm[1],
|
||||
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape,
|
||||
[1, math_inst.instruction_shape[1],
|
||||
math_inst.instruction_shape[2] * 4]))
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_1sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_1sm[1],
|
||||
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape,
|
||||
[math_inst.instruction_shape[0], 1,
|
||||
math_inst.instruction_shape[2] * 4]))
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.f16,
|
||||
"d_type" : DataType.f16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.f32,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
]
|
||||
|
||||
# Set alignment d based on Destination format.
|
||||
for layout in layouts:
|
||||
layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
|
||||
|
||||
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
|
||||
for data_type in data_types:
|
||||
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
|
||||
( data_type["d_type"] == DataType.e5m2 ):
|
||||
continue
|
||||
|
||||
is_runtime_datatype_a = is_runtime_datatype(data_type["a_type"])
|
||||
is_runtime_datatype_b = is_runtime_datatype(data_type["d_type"])
|
||||
|
||||
# A/B datatypes should be both static or dynamic
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
|
||||
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, epi_schedule]],
|
||||
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
|
||||
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
# SM100 MMA with mixed F4/F6/F8 inputs + without block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
@ -10318,6 +10647,11 @@ def GenerateSM100(manifest, cuda_version):
|
||||
|
||||
# StreamK is included in regular generation
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
|
||||
|
||||
# Blockwise kernels
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x)
|
||||
|
||||
#
|
||||
# Sparse Gemm
|
||||
#
|
||||
@ -10755,6 +11089,8 @@ def GenerateSM90(manifest, cuda_version):
|
||||
GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
@ -10819,6 +11155,8 @@ if __name__ == "__main__":
|
||||
|
||||
manifest = Manifest(args)
|
||||
|
||||
archs = args.architectures.split(';')
|
||||
|
||||
GenerateSM50(manifest, args.cuda_version)
|
||||
GenerateSM60(manifest, args.cuda_version)
|
||||
GenerateSM61(manifest, args.cuda_version)
|
||||
@ -10827,8 +11165,8 @@ if __name__ == "__main__":
|
||||
GenerateSM80(manifest, args.cuda_version)
|
||||
GenerateSM89(manifest, args.cuda_version)
|
||||
GenerateSM90(manifest, args.cuda_version)
|
||||
|
||||
blackwell_enabled_arch = args.architectures in ["100a", "101a", "120a"]
|
||||
|
||||
blackwell_enabled_arch = any(arch in ["100a", "101a", "120a"] for arch in archs)
|
||||
if blackwell_enabled_arch:
|
||||
GenerateSM100(manifest, args.cuda_version)
|
||||
GenerateSM120(manifest, args.cuda_version)
|
||||
|
||||
@ -324,8 +324,12 @@ def is_complex(data_type):
|
||||
def is_block_scaled(gemm_kind):
|
||||
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
|
||||
|
||||
def is_blockwise(gemm_kind):
|
||||
return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
|
||||
|
||||
def is_grouped(gemm_kind):
|
||||
return gemm_kind in (GemmKind.GroupedUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
|
||||
return gemm_kind in (GemmKind.GroupedUniversal3x,
|
||||
GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
|
||||
|
||||
#
|
||||
def get_complex_from_real(real_type):
|
||||
@ -493,6 +497,9 @@ class KernelScheduleType(enum.Enum):
|
||||
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
|
||||
BlockwiseTmaWarpSpecializedCooperative = enum_auto()
|
||||
PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto()
|
||||
|
||||
TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
ImplicitTmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
@ -518,6 +525,13 @@ class KernelScheduleType(enum.Enum):
|
||||
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
|
||||
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
@ -547,6 +561,8 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum',
|
||||
|
||||
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
|
||||
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
|
||||
|
||||
@ -564,6 +580,12 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100',
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100',
|
||||
|
||||
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100',
|
||||
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100',
|
||||
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
@ -574,6 +596,8 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
|
||||
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
|
||||
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100",
|
||||
@ -608,6 +632,8 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
|
||||
KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm',
|
||||
KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm',
|
||||
@ -626,6 +652,11 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
|
||||
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
|
||||
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
|
||||
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
@ -636,6 +667,8 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
|
||||
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
|
||||
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
@ -730,6 +763,7 @@ def to_grouped_schedule(schedule, grouped):
|
||||
group_schedule_map = {
|
||||
# SM90
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
|
||||
@ -745,6 +779,9 @@ def to_grouped_schedule(schedule, grouped):
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
|
||||
|
||||
}
|
||||
|
||||
return group_schedule_map[schedule]
|
||||
@ -932,6 +969,8 @@ class GemmKind(enum.Enum):
|
||||
BlockScaledUniversal3x = enum_auto()
|
||||
GroupedUniversal3x = enum_auto()
|
||||
GroupedBlockScaledUniversal3x = enum_auto()
|
||||
BlockwiseUniversal3x = enum_auto()
|
||||
GroupedBlockwiseUniversal3x = enum_auto()
|
||||
|
||||
#
|
||||
GemmKindNames = {
|
||||
@ -945,7 +984,9 @@ GemmKindNames = {
|
||||
GemmKind.Grouped: "gemm_grouped",
|
||||
GemmKind.BlockScaledUniversal3x: "gemm",
|
||||
GemmKind.GroupedUniversal3x: "gemm_grouped",
|
||||
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped"
|
||||
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
|
||||
GemmKind.BlockwiseUniversal3x: "gemm",
|
||||
GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped"
|
||||
}
|
||||
|
||||
#
|
||||
@ -1149,7 +1190,7 @@ class MathInstruction:
|
||||
#
|
||||
class TileDescription:
|
||||
|
||||
def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1]):
|
||||
def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1], explicit_vector_sizes = None):
|
||||
self.threadblock_shape = threadblock_shape
|
||||
self.tile_shape = threadblock_shape
|
||||
self.stages = stages
|
||||
@ -1158,6 +1199,7 @@ class TileDescription:
|
||||
self.minimum_compute_capability = min_compute
|
||||
self.maximum_compute_capability = max_compute
|
||||
self.cluster_shape = cluster_shape
|
||||
self.explicit_vector_sizes = explicit_vector_sizes
|
||||
|
||||
def procedural_name(self):
|
||||
if self.minimum_compute_capability >= 90:
|
||||
|
||||
@ -511,16 +511,23 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
return [], []
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
|
||||
schedules = []
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
else:
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
return schedules, []
|
||||
return [], []
|
||||
|
||||
@ -547,26 +554,34 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
|
||||
if a_type_size > b_type_size:
|
||||
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecialized,
|
||||
epilogue_schedule
|
||||
])
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
epilogue_schedule
|
||||
])
|
||||
|
||||
if not is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecialized,
|
||||
epilogue_schedule
|
||||
])
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
epilogue_schedule
|
||||
])
|
||||
if cta_m >= 128:
|
||||
if a_type_size > b_type_size:
|
||||
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
|
||||
else:
|
||||
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
else:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
return schedules, []
|
||||
|
||||
if not is_aligned:
|
||||
if not is_aligned and not is_blockwise(gemm_kind):
|
||||
schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
|
||||
default_epilogue]]
|
||||
stream_k_schedules = []
|
||||
@ -585,7 +600,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
|
||||
schedules = []
|
||||
# Pruning: emit Void-C and Grouped kernels with persistent kernels only
|
||||
if (level >= 1 or not is_void_c) and not grouped:
|
||||
if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind):
|
||||
# Pruning: don't stamp out fp8 kernels with auto schedule
|
||||
if not is_fp8:
|
||||
schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
|
||||
@ -596,7 +611,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
if can_do_tma_epilogue:
|
||||
assert not requires_transposed_epilogue
|
||||
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
|
||||
if not is_fp8 or level >= 1:
|
||||
if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
|
||||
@ -618,14 +633,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
|
||||
|
||||
if can_do_cooperative:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
default_epilogue
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
||||
default_epilogue
|
||||
])
|
||||
else:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
default_epilogue
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
@ -640,14 +665,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
if can_do_tma_epilogue:
|
||||
assert not requires_transposed_epilogue
|
||||
if can_do_cooperative:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
else:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
|
||||
Reference in New Issue
Block a user