Files
cutlass/python/cutlass_library/heuristics_provider.py
2025-09-15 12:21:53 -04:00

176 lines
6.5 KiB
Python

#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Providers for kernel selection heuristics
"""
import sys
import os
import glob
import logging
import ctypes
import functools
try:
import builtins
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
raise ImportError("Disabling attempt to import cutlass_library")
from cutlass_library.library import DataType, LayoutType
except ImportError:
from library import DataType, LayoutType
class MatmulHeuristics:
def __init__(self, gpu = None):
import nvMatmulHeuristics
self.mmh_lib = nvMatmulHeuristics
self.gpu = gpu
if 'CUTLASS_NVMMH_SO_PATH' in os.environ:
nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH'])
else:
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
self.lh = nvmmhInterfaceEx(
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
load_discovery_implicitly=True,
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
)
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
def _layout_from_cutlass(self, layouts):
assert(len(layouts)==3)
full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts)
input_layouts = full_layout_str[:2].upper()
lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR")
return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout]
def _precision_from_cutlass_dtypes(self, dtypes):
dtype_to_cublas = {
DataType.f64: 'D',
DataType.f32: 'S',
DataType.f16: 'H',
DataType.bf16: 'T',
DataType.e4m3: 'Q',
DataType.e5m2: 'R',
DataType.s32: 'I',
DataType.s8: 'B',
}
dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes
a_c = dtype_to_cublas[dtype_a]
if a_c.lower() != 'q':
return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
else:
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
def set_cta_div_n(self, div_n):
cta_n_div_requirement = ctypes.c_int(div_n)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
ctypes.byref(cta_n_div_requirement),
ctypes.sizeof(cta_n_div_requirement)
)
def set_cta_div_m(self, div_m):
cta_m_div_requirement = ctypes.c_int(div_m)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
ctypes.byref(cta_m_div_requirement),
ctypes.sizeof(cta_m_div_requirement)
)
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
if use_fast_acc:
disable_fast_acc_for_fp8 = ctypes.c_int(0)
else:
disable_fast_acc_for_fp8 = ctypes.c_int(1)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
ctypes.byref(disable_fast_acc_for_fp8),
ctypes.sizeof(disable_fast_acc_for_fp8)
)
precision = self._precision_from_cutlass_dtypes(dtypes)
layout = self._layout_from_cutlass(layouts)
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
ret = []
for c in configs:
kernel = c['kernel']
problem = c['problem']
r = {}
r['estimated_runtime'] = c['runtime']
r['cta_tile_m'] = kernel.cta_tile_m
r['cta_tile_n'] = kernel.cta_tile_n
r['cta_tile_k'] = kernel.cta_tile_k
r['instr_tile_m'] = kernel.instr_tile_m
r['instr_tile_n'] = kernel.instr_tile_n
r['instr_tile_k'] = kernel.instr_tile_k
r['warp_tile_m'] = kernel.warp_tile_m
r['warp_tile_n'] = kernel.warp_tile_n
r['warp_tile_k'] = kernel.warp_tile_k
r['cluster_m'] = kernel.cluster_m
r['cluster_n'] = kernel.cluster_n
r['cluster_k'] = 1
r['layout_a'] = layouts[0]
r['layout_b'] = layouts[1]
r['layout_d'] = layouts[2]
r['dtype_a'] = dtypes[0]
r['dtype_b'] = dtypes[1]
r['dtype_acc'] = dtypes[2]
r['dtype_c'] = dtypes[3]
r['dtype_d'] = dtypes[4]
r['alignment_a'] = align_a
r['alignment_b'] = align_b
r['swizzle_size'] = kernel.swizzle_factor
r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n'
r['split_k_slices'] = kernel.split_k
r['use_fast_acc'] = use_fast_acc
r['voidC'] = voidC
ret.append(r)
return ret