Updates for CUTLASS 3.4.1 (#1346)
* Updates for CUTLASS 3.4.1 * minor epi change
This commit is contained in:
@ -40,7 +40,7 @@ import cutlass_library
|
||||
def _cuda_install_path_from_nvcc() -> str:
|
||||
import subprocess
|
||||
# Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
|
||||
result = subprocess.run(['which', 'nvcc'], capture_output=True)
|
||||
result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True)
|
||||
if result.returncode != 0:
|
||||
raise Exception(f'Unable to find nvcc via `which` utility.')
|
||||
|
||||
@ -121,7 +121,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '3.4.0'
|
||||
this.__version__ = '3.4.1'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
@ -169,7 +169,7 @@ def initialize_cuda_context():
|
||||
raise Exception("No CUDA devices found")
|
||||
device_id = 0
|
||||
|
||||
this._device_id = device_id
|
||||
this._device_id = int(device_id)
|
||||
|
||||
|
||||
def device_id() -> int:
|
||||
|
||||
@ -213,8 +213,12 @@ def get_mainloop_arguments_3x(
|
||||
return _MainloopArgumentsTma
|
||||
|
||||
|
||||
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args):
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
|
||||
if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
|
||||
else:
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
|
||||
if hasattr(epilogue_functor, "visitor"):
|
||||
class _EpilogueArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
|
||||
@ -157,19 +157,41 @@ class LinearCombination(EpilogueFunctorBase):
|
||||
c_element_epilogue = dtype2ctype[self.element_epilogue]
|
||||
element_epilogue = self.element_epilogue
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
class _EpilogueOutputOpParamsEVT(ctypes.Structure):
|
||||
"""
|
||||
Epilogue params when using the default linear combination of EVT, which
|
||||
does not currently use {alpha,beta}_ptr_array
|
||||
"""
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p)
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
("alpha_ptr_array", ctypes.c_void_p),
|
||||
("beta_ptr_array", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
|
||||
return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)
|
||||
|
||||
self.epilogue_type = _EpilogueOutputOpParams
|
||||
self.epilogue_type_evt = _EpilogueOutputOpParamsEVT
|
||||
|
||||
def emit(self):
|
||||
return super().emit(self.tag, self.template_arguments)
|
||||
|
||||
@ -241,10 +241,10 @@ class EVTFrontendBase:
|
||||
:param name: the name of the graph
|
||||
"""
|
||||
drawer = EVTGraphDrawer(self.dag_ir, name)
|
||||
if drawer.dot_available:
|
||||
try:
|
||||
for name, graph in drawer.get_dot_graph():
|
||||
graph.write_svg(f"./{name}.svg")
|
||||
else:
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"'dot' is not found in path. GraphDrawer is disabled. "
|
||||
"Please install it with 'sudo apt-get install graphviz'."
|
||||
|
||||
@ -61,22 +61,6 @@ class EVTGraphDrawer:
|
||||
self._dot_graphs = {}
|
||||
|
||||
self._dot_graphs[name] = self._to_dot(graph, name)
|
||||
self.dot_available = self._check_dot_availability()
|
||||
|
||||
def _check_dot_availability(self):
|
||||
"""
|
||||
Check if graphviz is installed
|
||||
"""
|
||||
try:
|
||||
# Run the 'dot' command and capture its output
|
||||
result = subprocess.run(
|
||||
["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
# Check if the command was successful and the output contains version information
|
||||
if result.returncode == 0 and "dot - graphviz" in result.stderr:
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _get_node_style(self, node):
|
||||
template = {
|
||||
|
||||
@ -325,7 +325,7 @@ class GemmArguments2x(ArgumentBase):
|
||||
def initialize(self):
|
||||
launch_config = self.operation.rt_module.plan(self)
|
||||
|
||||
# Get the host and evice workspace
|
||||
# Get the host and device workspace
|
||||
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
|
||||
|
||||
if device_workspace_size > 0:
|
||||
@ -512,6 +512,18 @@ class GemmArguments3x(GemmArguments2x):
|
||||
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
|
||||
|
||||
def get_arguments(self):
|
||||
mainloop_args = get_mainloop_arguments_3x(
|
||||
self.operation.tile_description.kernel_schedule,
|
||||
self.operation.A.element,
|
||||
self.operation.B.element,
|
||||
self.operation.A.alignment,
|
||||
self.operation.B.alignment
|
||||
)
|
||||
scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
|
||||
uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
|
||||
argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
|
||||
mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)
|
||||
|
||||
problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)
|
||||
|
||||
if self.batch_count > 1:
|
||||
@ -539,9 +551,12 @@ class GemmArguments3x(GemmArguments2x):
|
||||
)
|
||||
|
||||
# Set of mainloop arguments needed for this kernel
|
||||
mainloop = self.operation.rt_module.mainloop_args.from_generic_mainloop_args(generic_args)
|
||||
mainloop = mainloop_args.from_generic_mainloop_args(generic_args)
|
||||
|
||||
epilogue = self.operation.rt_module.epilogue_args(
|
||||
if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
|
||||
self.output_op = self.output_op.to_evt_params()
|
||||
|
||||
epilogue = epilogue_args(
|
||||
self.output_op,
|
||||
int(self.ptr_C),
|
||||
stride_C,
|
||||
@ -550,15 +565,15 @@ class GemmArguments3x(GemmArguments2x):
|
||||
)
|
||||
|
||||
# Set hardware info
|
||||
hw_info = self.operation.rt_module.hw_info(0, device_sm_count())
|
||||
hw_info_ = hw_info(0, device_sm_count())
|
||||
|
||||
self.arguments = self.operation.argument_type(
|
||||
self.arguments = argument_type(
|
||||
int(self.gemm_mode),
|
||||
problem_size_,
|
||||
mainloop,
|
||||
epilogue,
|
||||
hw_info,
|
||||
self.operation.rt_module.scheduler_args
|
||||
hw_info_,
|
||||
scheduler_args
|
||||
)
|
||||
return self.arguments
|
||||
|
||||
@ -1119,6 +1134,10 @@ extern "C" {
|
||||
|
||||
using GemmType = ${operation_name}_base;
|
||||
|
||||
bool ${operation_name}_uses_default_epilogue() {
|
||||
return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
|
||||
}
|
||||
|
||||
// Get the workspace size
|
||||
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
|
||||
return GemmType::get_workspace_size(*argument);
|
||||
@ -1163,19 +1182,10 @@ extern "C" {
|
||||
"get_grid_shape": dim3_,
|
||||
"get_block_shape": dim3_,
|
||||
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
|
||||
"get_kernel_workspace_size": ctypes.c_uint64
|
||||
"get_kernel_workspace_size": ctypes.c_uint64,
|
||||
"uses_default_epilogue": ctypes.c_bool,
|
||||
}
|
||||
self.emitter = EmitGemmUniversalInstance3x("_type")
|
||||
self.mainloop_args = get_mainloop_arguments_3x(
|
||||
operation.tile_description.kernel_schedule,
|
||||
operation.A.element,
|
||||
operation.B.element,
|
||||
operation.A.alignment,
|
||||
operation.B.alignment
|
||||
)
|
||||
self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler)
|
||||
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(
|
||||
self.mainloop_args, operation.epilogue_functor, self.scheduler_args)
|
||||
|
||||
def get_device_workspace_size(self, arguments: GemmArguments3x):
|
||||
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='3.4.0',
|
||||
version='3.4.1',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='3.4.0',
|
||||
version='3.4.1',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user