v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -253,7 +253,8 @@ def _getInstType(input_precision, accumulate_precision, math_instruction):
return inst
# TODO: Computes FLOPs/Bytes for GEMM - revisit for conv
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0):
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1):
assert not (batch_count > 1 and num_groups > 1)
# TODO: adjust for sparsity
gmem_bytes = (
@ -269,16 +270,15 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0):
gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n
flops += 2 * m * n
gmem_bytes *= batch_count
flops *= batch_count
multiplier = max(batch_count, num_groups)
gmem_bytes *= multiplier
flops *= multiplier
return flops / gmem_bytes
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
):
profiler_reference_computing = "--verification-providers=device --providers=cutlass"
# beta values for L0 and L1
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
@ -303,15 +303,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
]
sm100_mma_data_type_runtime_dtype = [
'gemm_f4_f4_f32_f32_f32',
'gemm_f6_f6_f32_f32_f32',
'gemm_f8_f8_f32_f32_f32',
]
sm100_mma_data_type_mergeable = [
'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
'gemm_e2m1_e2m1_f32_f32_f32',
'gemm_e3m2_e3m2_f32_f32_f32',
'gemm.*f4_f4_f32_f32_f32',
'gemm.*f6_f6_f32_f32_f32',
'gemm.*f8_f8_f32_f32_f32',
]
sm100_mma_cluster_size = [
@ -327,9 +321,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
]
# regex list must be in kernel procedural name order
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
@ -340,25 +331,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
# Block Scale Gemm
#
block_scaled_data_type_base = [
block_scaled_data_type = [
# runtime datatypes
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
block_scaled_data_type_mergeable = [
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
]
block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
block_scaled_cluster_size = [
'4x4x1', '2x1x1',
'0x0x1' # dynamic cluster
@ -366,27 +347,25 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_layouts = ['tnt']
# regex list must be in kernel procedural name order
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
if arch == "100a" or arch == "100f":
if arch in ["100a", "100f"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch == "101a" or arch == "101f":
elif arch in ["101a", "101f",
]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch == "120a" or arch == "120f":
elif arch in ["120a", "120f"]:
# blockscaled sm120_mma kernels
blockscaled_sm120_mma_kernel_cta_tiles = [
@ -403,18 +382,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
else:
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
raise Exception(error_message)
# Statically encoded kernels are still added to generated_kernels
# but are filtered out from the testing commands to reduce test duration.
# The mergeable_kernel_filter specifies the kernels that are already covered
# by the runtime datatype tests so that we safely mark them off
# without changing the test coverage.
mergeable_kernel_filter = f"({mergeable_sm100_mma_filter_regex_1sm})|" \
f"({mergeable_sm100_mma_filter_regex_2sm})|" \
f"({mergeable_block_scaled_filter_regex_1sm})|" \
f"({mergeable_block_scaled_filter_regex_2sm})"
elif mode == "functional_L1":
elif mode == "functional_L1":
sm100_mma_cluster_size = [
'0x0x1' # dynamic cluster
]
@ -486,10 +455,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv")
if is_runtime_datatype_enabled:
mergeable_kernel_filter_re = re.compile(mergeable_kernel_filter)
kernel_filter_re = re.compile(kernel_filter)
testcase_counter = 0
kernels_emitted = 0
@ -517,12 +482,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
if 'f16_f16_f16_void_f16' not in kernel_name :
continue
# Filter out the statically encoded tests which are
# covered by runtime datatype tests to avoid repetition.
if is_runtime_datatype_enabled and len(mergeable_kernel_filter_re.findall(kernel_name)) != 0:
continue
kernels_emitted += 1
kernel_name_set.add(kernel_name)
hashed_kernel_name = hash_cutlass_string(kernel_name)
@ -685,9 +644,18 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
gemm_op = "gemm"
profiler_reference_computing_override = profiler_reference_computing
grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
num_groups = 1
if "bstensorop" in kernel_name:
profiler_reference_computing_override = "--mode=trace"
if grouped:
gemm_op = "grouped_gemm"
num_groups = 3 # small to limit test time in host block-scaled reference kernels
batch_count = 1
elif "bstensorop" in kernel_name:
gemm_op = "block_scaled_gemm"
elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
gemm_op = "blockwise_gemm"
problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]
@ -704,7 +672,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'n' : n,
'k' : k,
'beta' : beta,
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta)
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups)
},
"runtime_params": {
'ctas_per_mma_instruction' : ctas_per_mma_instruction,
@ -732,6 +700,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
f" --m={str(m)}" +
f" --n={str(n)}" +
f" --k={str(k)}" +
(f" --num_groups={str(num_groups)}" if grouped else "") +
f" --cluster_m={str(cluster_shape_m)}" +
f" --cluster_n={str(cluster_shape_n)}" +
f" --cluster_k={str(cluster_shape_k)}" +
@ -739,7 +708,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
f" --cluster_n_fallback={str(cluster_n_fallback)}" +
f" --cluster_k_fallback={str(cluster_k_fallback)}" +
f" --beta={str(beta)}" +
f" --batch_count={str(batch_count)}" +
("" if grouped else f" --batch_count={str(batch_count)}") +
f" --swizzle_size={str(swizzle_size)}" +
f" --verification-required={str(verification_required).lower()}"
] \
@ -752,7 +721,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
testcase_metadata.append(json.dumps(metadata_dict))
testlist_csv_rows.append(testcase_metadata)
testcase_counter += 1
alpha = 1.0
if dynamic_datatype: