CUTLASS 3.6.0 (#1850)
* v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -121,7 +121,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '3.5.1'
|
||||
this.__version__ = '3.6.0'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
|
||||
@ -37,7 +37,7 @@ import numpy as np
|
||||
from scipy.special import erf
|
||||
|
||||
from cutlass_library import DataType, DataTypeTag
|
||||
from cutlass.backend.c_types import MatrixCoord_
|
||||
from cutlass.backend.c_types import MatrixCoord_, tuple_factory
|
||||
from cutlass.backend.frontend import NumpyFrontend
|
||||
from cutlass.backend.library import ActivationOp, ActivationOpTag
|
||||
from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
|
||||
@ -162,11 +162,15 @@ class LinearCombination(EpilogueFunctorBase):
|
||||
Epilogue params when using the default linear combination of EVT, which
|
||||
does not currently use {alpha,beta}_ptr_array
|
||||
"""
|
||||
|
||||
stride_type = tuple_factory((0,0,1), "int64_t", [0])
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
("dalpha", stride_type),
|
||||
("dbeta", stride_type),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
|
||||
@ -164,7 +164,7 @@ class Sm90RowBroadcastImpl(RowBroadcastImpl):
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
@ -183,7 +183,7 @@ class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
|
||||
@ -253,8 +253,8 @@ _CUTLASS_TYPE_TO_TORCH_TYPE = {
|
||||
DataType.f16: "torch::kF16",
|
||||
DataType.f32: "torch::kF32",
|
||||
DataType.f64: "torch::kF64",
|
||||
DataType.s8: "torch::I8",
|
||||
DataType.s32: "torch::I32",
|
||||
DataType.s8: "torch::kI8",
|
||||
DataType.s32: "torch::kI32",
|
||||
}
|
||||
|
||||
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
|
||||
|
||||
Reference in New Issue
Block a user