Updates for CUTLASS 3.4.1 (#1346)

* Updates for CUTLASS 3.4.1

* Minor epilogue change
This commit is contained in:
ANIKET SHIVAM
2024-02-15 12:48:34 -08:00
committed by GitHub
parent 47a3ebbea9
commit bbe579a9e3
49 changed files with 800 additions and 451 deletions

View File

@ -40,7 +40,7 @@ import cutlass_library
def _cuda_install_path_from_nvcc() -> str:
import subprocess
# Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
result = subprocess.run(['which', 'nvcc'], capture_output=True)
result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True)
if result.returncode != 0:
raise Exception(f'Unable to find nvcc via `which` utility.')
@ -121,7 +121,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '3.4.0'
this.__version__ = '3.4.1'
from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch
@ -169,7 +169,7 @@ def initialize_cuda_context():
raise Exception("No CUDA devices found")
device_id = 0
this._device_id = device_id
this._device_id = int(device_id)
def device_id() -> int:

View File

@ -213,8 +213,12 @@ def get_mainloop_arguments_3x(
return _MainloopArgumentsTma
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
else:
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
if hasattr(epilogue_functor, "visitor"):
class _EpilogueArguments(ctypes.Structure):
_fields_ = [

View File

@ -157,19 +157,41 @@ class LinearCombination(EpilogueFunctorBase):
c_element_epilogue = dtype2ctype[self.element_epilogue]
element_epilogue = self.element_epilogue
class _EpilogueOutputOpParams(ctypes.Structure):
class _EpilogueOutputOpParamsEVT(ctypes.Structure):
"""
Epilogue params when using the default linear combination of EVT, which
does not currently use {alpha,beta}_ptr_array
"""
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p)
("beta_ptr", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
("alpha_ptr_array", ctypes.c_void_p),
("beta_ptr_array", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)
self.epilogue_type = _EpilogueOutputOpParams
self.epilogue_type_evt = _EpilogueOutputOpParamsEVT
def emit(self):
return super().emit(self.tag, self.template_arguments)

View File

@ -241,10 +241,10 @@ class EVTFrontendBase:
:param name: the name of the graph
"""
drawer = EVTGraphDrawer(self.dag_ir, name)
if drawer.dot_available:
try:
for name, graph in drawer.get_dot_graph():
graph.write_svg(f"./{name}.svg")
else:
except:
raise RuntimeError(
"'dot' is not found in path. GraphDrawer is disabled. "
"Please install it with 'sudo apt-get install graphviz'."

View File

@ -61,22 +61,6 @@ class EVTGraphDrawer:
self._dot_graphs = {}
self._dot_graphs[name] = self._to_dot(graph, name)
self.dot_available = self._check_dot_availability()
def _check_dot_availability(self):
"""
Check if graphviz is installed
"""
try:
# Run the 'dot' command and capture its output
result = subprocess.run(
["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Check if the command was successful and the output contains version information
if result.returncode == 0 and "dot - graphviz" in result.stderr:
return True
except FileNotFoundError:
pass
return False
def _get_node_style(self, node):
template = {

View File

@ -325,7 +325,7 @@ class GemmArguments2x(ArgumentBase):
def initialize(self):
launch_config = self.operation.rt_module.plan(self)
# Get the host and evice workspace
# Get the host and device workspace
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
if device_workspace_size > 0:
@ -512,6 +512,18 @@ class GemmArguments3x(GemmArguments2x):
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
def get_arguments(self):
mainloop_args = get_mainloop_arguments_3x(
self.operation.tile_description.kernel_schedule,
self.operation.A.element,
self.operation.B.element,
self.operation.A.alignment,
self.operation.B.alignment
)
scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)
problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)
if self.batch_count > 1:
@ -539,9 +551,12 @@ class GemmArguments3x(GemmArguments2x):
)
# Set of mainloop arguments needed for this kernel
mainloop = self.operation.rt_module.mainloop_args.from_generic_mainloop_args(generic_args)
mainloop = mainloop_args.from_generic_mainloop_args(generic_args)
epilogue = self.operation.rt_module.epilogue_args(
if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
self.output_op = self.output_op.to_evt_params()
epilogue = epilogue_args(
self.output_op,
int(self.ptr_C),
stride_C,
@ -550,15 +565,15 @@ class GemmArguments3x(GemmArguments2x):
)
# Set hardware info
hw_info = self.operation.rt_module.hw_info(0, device_sm_count())
hw_info_ = hw_info(0, device_sm_count())
self.arguments = self.operation.argument_type(
self.arguments = argument_type(
int(self.gemm_mode),
problem_size_,
mainloop,
epilogue,
hw_info,
self.operation.rt_module.scheduler_args
hw_info_,
scheduler_args
)
return self.arguments
@ -1119,6 +1134,10 @@ extern "C" {
using GemmType = ${operation_name}_base;
bool ${operation_name}_uses_default_epilogue() {
return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
}
// Get the workspace size
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
return GemmType::get_workspace_size(*argument);
@ -1163,19 +1182,10 @@ extern "C" {
"get_grid_shape": dim3_,
"get_block_shape": dim3_,
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
"get_kernel_workspace_size": ctypes.c_uint64
"get_kernel_workspace_size": ctypes.c_uint64,
"uses_default_epilogue": ctypes.c_bool,
}
self.emitter = EmitGemmUniversalInstance3x("_type")
self.mainloop_args = get_mainloop_arguments_3x(
operation.tile_description.kernel_schedule,
operation.A.element,
operation.B.element,
operation.A.alignment,
operation.B.alignment
)
self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler)
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(
self.mainloop_args, operation.epilogue_functor, self.scheduler_args)
def get_device_workspace_size(self, arguments: GemmArguments3x):
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='3.4.0',
version='3.4.1',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='3.4.0',
version='3.4.1',
description='Python implementation of CuTe',
packages=['pycute'],
)