update 3.8 v2 (#2112)
* update 3.8 v2 * update 3.8 --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@ -90,19 +90,32 @@ def hash_cutlass_string(input_string):
|
||||
def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
|
||||
# Define a dictionary mapping the detected types to runtime values
|
||||
datatype_map = {
|
||||
'_f4_': '_' + runtime_datatype_a + '_',
|
||||
'_f6_': '_' + runtime_datatype_b + '_',
|
||||
'_f8_': '_' + runtime_datatype_a + '_',
|
||||
'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
|
||||
'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
|
||||
'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
||||
}
|
||||
|
||||
# Use regex to identify and replace _f4_, _f6_, or _f8_ in the kernel name
|
||||
def substitute(match):
|
||||
datatype = match.group(0) # This is the matched "_f4_", "_f6_", or "_f8_"
|
||||
return datatype_map.get(datatype, datatype) # Replace or leave as is
|
||||
# Regular expression to detect all the keys in datatype_map
|
||||
pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')')
|
||||
|
||||
# Replace detected patterns using the dictionary
|
||||
updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name)
|
||||
|
||||
# Regex to find "_f4_", "_f6_", or "_f8_" in the hashed_kernel_name
|
||||
updated_kernel_name = re.sub(r'_f4_|_f6_|_f8_', substitute, hashed_kernel_name)
|
||||
|
||||
return updated_kernel_name
|
||||
|
||||
# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.
|
||||
|
||||
Reference in New Issue
Block a user