v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -278,7 +278,11 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
):
profiler_reference_computing = "--verification-providers=device --providers=cutlass"
# For functional testing, we prefer to run reference computing on device if any
reference_device_archs = ["100a"]
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
# beta values for L0 and L1
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
@ -408,7 +412,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})|"
f"({block_scaled_filter_regex_2sm})"
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
sm120_mma_kernel_cta_tiles = [
# h1688, s1688, i16832, i8816
@ -545,11 +549,22 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
elif "ue8m0xf8_ue8m0xf8" in kernel_name:
runtime_input_datatypes = [['e4m3','e4m3']]
if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
profiler_flags_for_verification = "host"
# reduce L1 test runtime if reference kernel is not running on device.
if mode == "functional_L1" and profiler_flags_for_verification == "host" :
problem_waves = [0.5, 2.5]
if dynamic_cluster:
if mode == "functional_L0":
runtime_cluster_shapes = [[1,1,1], [2,2,1]]
else:
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
# reduce L1 test runtime if reference kernel is not running on device.
if profiler_flags_for_verification == "host":
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]]
cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
else:
runtime_cluster_shapes = [operation.tile_description.cluster_shape]
@ -643,11 +658,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
batch_count = 3 if mode == "functional_L0" else 5
gemm_op = "gemm"
profiler_reference_computing_override = profiler_reference_computing
grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
num_groups = 1
if "bstensorop" in kernel_name:
profiler_reference_computing_override = "--mode=trace"
if grouped:
gemm_op = "grouped_gemm"
num_groups = 3 # small to limit test time in host block-scaled reference kernels
@ -695,7 +707,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
testcase_metadata = [
f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" +
f"cutlass_profiler --operation={gemm_op}" +
(f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") +
f" --error-on-no-match --error-if-nothing-is-profiled" +
f" --kernels={kernel_name}" +
f" --m={str(m)}" +
f" --n={str(n)}" +