v4.1 release update v2. (#2481)
This commit is contained in:
@@ -278,7 +278,11 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
|
||||
|
||||
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
):
|
||||
profiler_reference_computing = "--verification-providers=device --providers=cutlass"
|
||||
# For functional testing, we prefer to run reference computing on device if any
|
||||
reference_device_archs = ["100a"]
|
||||
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
|
||||
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
|
||||
|
||||
# beta values for L0 and L1
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
@@ -408,7 +412,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})|"
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
|
||||
sm120_mma_kernel_cta_tiles = [
|
||||
# h1688, s1688, i16832, i8816
|
||||
@@ -545,11 +549,22 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
elif "ue8m0xf8_ue8m0xf8" in kernel_name:
|
||||
runtime_input_datatypes = [['e4m3','e4m3']]
|
||||
|
||||
if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
|
||||
profiler_flags_for_verification = "host"
|
||||
|
||||
# reduce L1 test runtime if reference kernel is not running on device.
|
||||
if mode == "functional_L1" and profiler_flags_for_verification == "host" :
|
||||
problem_waves = [0.5, 2.5]
|
||||
|
||||
|
||||
if dynamic_cluster:
|
||||
if mode == "functional_L0":
|
||||
runtime_cluster_shapes = [[1,1,1], [2,2,1]]
|
||||
else:
|
||||
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
|
||||
# reduce L1 test runtime if reference kernel is not running on device.
|
||||
if profiler_flags_for_verification == "host":
|
||||
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]]
|
||||
cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
|
||||
else:
|
||||
runtime_cluster_shapes = [operation.tile_description.cluster_shape]
|
||||
@@ -643,11 +658,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
batch_count = 3 if mode == "functional_L0" else 5
|
||||
gemm_op = "gemm"
|
||||
|
||||
profiler_reference_computing_override = profiler_reference_computing
|
||||
grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
|
||||
num_groups = 1
|
||||
if "bstensorop" in kernel_name:
|
||||
profiler_reference_computing_override = "--mode=trace"
|
||||
if grouped:
|
||||
gemm_op = "grouped_gemm"
|
||||
num_groups = 3 # small to limit test time in host block-scaled reference kernels
|
||||
@ -695,7 +707,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
|
||||
|
||||
testcase_metadata = [
|
||||
f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" +
|
||||
f"cutlass_profiler --operation={gemm_op}" +
|
||||
(f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") +
|
||||
f" --error-on-no-match --error-if-nothing-is-profiled" +
|
||||
f" --kernels={kernel_name}" +
|
||||
f" --m={str(m)}" +
|
||||
f" --n={str(n)}" +
|
||||
|
||||
Reference in New Issue
Block a user