v4.0 update. (#2371)
This commit is contained in:
@ -253,7 +253,8 @@ def _getInstType(input_precision, accumulate_precision, math_instruction):
|
||||
|
||||
return inst
|
||||
# TODO: Computes FLOps/Bytes for GEMM - revisit for conv
|
||||
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0):
|
||||
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1):
|
||||
assert not (batch_count > 1 and num_groups > 1)
|
||||
|
||||
# TODO: adjust for sparsity
|
||||
gmem_bytes = (
|
||||
@ -269,16 +270,15 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0):
|
||||
gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n
|
||||
flops += 2 * m * n
|
||||
|
||||
gmem_bytes *= batch_count
|
||||
flops *= batch_count
|
||||
multiplier = max(batch_count, num_groups)
|
||||
gmem_bytes *= multiplier
|
||||
flops *= multiplier
|
||||
|
||||
return flops / gmem_bytes
|
||||
|
||||
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
):
|
||||
profiler_reference_computing = "--verification-providers=device --providers=cutlass"
|
||||
|
||||
|
||||
# beta values for L0 and L1
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
@ -303,15 +303,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
|
||||
sm100_mma_data_type_runtime_dtype = [
|
||||
'gemm_f4_f4_f32_f32_f32',
|
||||
'gemm_f6_f6_f32_f32_f32',
|
||||
'gemm_f8_f8_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_data_type_mergeable = [
|
||||
'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
|
||||
'gemm_e2m1_e2m1_f32_f32_f32',
|
||||
'gemm_e3m2_e3m2_f32_f32_f32',
|
||||
'gemm.*f4_f4_f32_f32_f32',
|
||||
'gemm.*f6_f6_f32_f32_f32',
|
||||
'gemm.*f8_f8_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_cluster_size = [
|
||||
@ -327,9 +321,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
@ -340,25 +331,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# Block Scale Gemm
|
||||
#
|
||||
|
||||
block_scaled_data_type_base = [
|
||||
block_scaled_data_type = [
|
||||
# runtime datatypes
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type_mergeable = [
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
|
||||
|
||||
block_scaled_cluster_size = [
|
||||
'4x4x1', '2x1x1',
|
||||
'0x0x1' # dynamic cluster
|
||||
@ -366,27 +347,25 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
block_scaled_layouts = ['tnt']
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
if arch == "100a" or arch == "100f":
|
||||
if arch in ["100a", "100f"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "101a" or arch == "101f":
|
||||
elif arch in ["101a", "101f",
|
||||
]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "120a" or arch == "120f":
|
||||
elif arch in ["120a", "120f"]:
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
@ -403,18 +382,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
|
||||
raise Exception(error_message)
|
||||
|
||||
# Statically encoded kernels are still added to generated_kernels
|
||||
# but are filtered out from the testing commands to reduce test duration.
|
||||
# The mergeable_kernel_filter specifies the kernels that are already covered
|
||||
# by the runtime datatype tests so that we safely mark them off
|
||||
# without changing the test coverage.
|
||||
mergeable_kernel_filter = f"({mergeable_sm100_mma_filter_regex_1sm})|" \
|
||||
f"({mergeable_sm100_mma_filter_regex_2sm})|" \
|
||||
f"({mergeable_block_scaled_filter_regex_1sm})|" \
|
||||
f"({mergeable_block_scaled_filter_regex_2sm})"
|
||||
elif mode == "functional_L1":
|
||||
|
||||
elif mode == "functional_L1":
|
||||
sm100_mma_cluster_size = [
|
||||
'0x0x1' # dynamic cluster
|
||||
]
|
||||
@ -486,10 +455,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv")
|
||||
|
||||
if is_runtime_datatype_enabled:
|
||||
mergeable_kernel_filter_re = re.compile(mergeable_kernel_filter)
|
||||
|
||||
|
||||
kernel_filter_re = re.compile(kernel_filter)
|
||||
testcase_counter = 0
|
||||
kernels_emitted = 0
|
||||
@ -517,12 +482,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
if 'f16_f16_f16_void_f16' not in kernel_name :
|
||||
continue
|
||||
|
||||
# Filter out the statically encoded tests which are
|
||||
# covered by runtime datatype tests to avoid repetition.
|
||||
if is_runtime_datatype_enabled and len(mergeable_kernel_filter_re.findall(kernel_name)) != 0:
|
||||
continue
|
||||
|
||||
|
||||
kernels_emitted += 1
|
||||
kernel_name_set.add(kernel_name)
|
||||
hashed_kernel_name = hash_cutlass_string(kernel_name)
|
||||
@ -685,9 +644,18 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
gemm_op = "gemm"
|
||||
|
||||
profiler_reference_computing_override = profiler_reference_computing
|
||||
grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
|
||||
num_groups = 1
|
||||
if "bstensorop" in kernel_name:
|
||||
profiler_reference_computing_override = "--mode=trace"
|
||||
if grouped:
|
||||
gemm_op = "grouped_gemm"
|
||||
num_groups = 3 # small to limit test time in host block-scaled reference kernels
|
||||
batch_count = 1
|
||||
elif "bstensorop" in kernel_name:
|
||||
gemm_op = "block_scaled_gemm"
|
||||
elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
|
||||
gemm_op = "blockwise_gemm"
|
||||
|
||||
problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]
|
||||
|
||||
@ -704,7 +672,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'n' : n,
|
||||
'k' : k,
|
||||
'beta' : beta,
|
||||
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta)
|
||||
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups)
|
||||
},
|
||||
"runtime_params": {
|
||||
'ctas_per_mma_instruction' : ctas_per_mma_instruction,
|
||||
@ -732,6 +700,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
f" --m={str(m)}" +
|
||||
f" --n={str(n)}" +
|
||||
f" --k={str(k)}" +
|
||||
(f" --num_groups={str(num_groups)}" if grouped else "") +
|
||||
f" --cluster_m={str(cluster_shape_m)}" +
|
||||
f" --cluster_n={str(cluster_shape_n)}" +
|
||||
f" --cluster_k={str(cluster_shape_k)}" +
|
||||
@ -739,7 +708,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
f" --cluster_n_fallback={str(cluster_n_fallback)}" +
|
||||
f" --cluster_k_fallback={str(cluster_k_fallback)}" +
|
||||
f" --beta={str(beta)}" +
|
||||
f" --batch_count={str(batch_count)}" +
|
||||
("" if grouped else f" --batch_count={str(batch_count)}") +
|
||||
f" --swizzle_size={str(swizzle_size)}" +
|
||||
f" --verification-required={str(verification_required).lower()}"
|
||||
] \
|
||||
@ -752,7 +721,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
testcase_metadata.append(json.dumps(metadata_dict))
|
||||
testlist_csv_rows.append(testcase_metadata)
|
||||
testcase_counter += 1
|
||||
|
||||
|
||||
alpha = 1.0
|
||||
|
||||
if dynamic_datatype:
|
||||
|
||||
Reference in New Issue
Block a user