update 3.8 v2 (#2112)

* update 3.8 v2

* update 3.8

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-02-19 19:03:14 -08:00
committed by GitHub
parent e9627ce55b
commit b84e9802d8
166 changed files with 3986 additions and 4037 deletions

View File

@ -90,19 +90,32 @@ def hash_cutlass_string(input_string):
def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
# Define a dictionary mapping the detected types to runtime values
datatype_map = {
'_f4_': '_' + runtime_datatype_a + '_',
'_f6_': '_' + runtime_datatype_b + '_',
'_f8_': '_' + runtime_datatype_a + '_',
'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
}
# Use regex to identify and replace _f4_, _f6_, or _f8_ in the kernel name
def substitute(match):
datatype = match.group(0) # This is the matched "_f4_", "_f6_", or "_f8_"
return datatype_map.get(datatype, datatype) # Replace or leave as is
# Regular expression to detect all the keys in datatype_map
pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')')
# Replace detected patterns using the dictionary
updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name)
# Regex to find "_f4_", "_f6_", or "_f8_" in the hashed_kernel_name
updated_kernel_name = re.sub(r'_f4_|_f6_|_f8_', substitute, hashed_kernel_name)
return updated_kernel_name
# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.