Updates for 3.4 release. (#1305)
This commit is contained in:
@ -1,3 +1,35 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.arguments import *
|
||||
from cutlass.backend.c_types import *
|
||||
from cutlass.backend.compiler import ArtifactManager
|
||||
|
||||
@ -56,16 +56,9 @@ class ArgumentBase:
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# tensor_C can be interpreted as the bias with bias=True in keyword args
|
||||
if "bias" in kwargs.keys():
|
||||
self.bias = kwargs["bias"]
|
||||
else:
|
||||
# by default, tensor_C is not bias
|
||||
self.bias = False
|
||||
self.bias = kwargs.get("bias", False)
|
||||
|
||||
if "stream" in kwargs.keys():
|
||||
self.stream = kwargs["stream"]
|
||||
else:
|
||||
self.stream = cuda.CUstream(0)
|
||||
self.stream = kwargs.get("stream", cuda.CUstream(0))
|
||||
|
||||
# RMM buffers used to track tensor lifetime
|
||||
self.buffers = {}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
import ctypes
|
||||
from typing import Union
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
import ctypes
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.evt.epilogue import EpilogueFunctorVisitor
|
||||
from cutlass.backend.evt.frontend import PythonASTFrontend
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Python registration for compute nodes in EVT
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Layout algebras
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
from cuda import cuda
|
||||
import numpy as np
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
import copy
|
||||
import ctypes
|
||||
@ -712,6 +712,8 @@ class GemmGroupedArguments:
|
||||
|
||||
self.gemm_arguments = []
|
||||
|
||||
self.stream = kwargs.get("stream", cuda.CUstream(0))
|
||||
|
||||
# Process the input arguments
|
||||
for idx, problem_size in enumerate(problem_sizes):
|
||||
M, N, K = problem_size.m, problem_size.n, problem_size.k
|
||||
@ -771,11 +773,6 @@ class GemmGroupedArguments:
|
||||
self.output_op = kwargs["output_op"]
|
||||
else:
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
if "stream" in kwargs.keys():
|
||||
self.stream = kwargs["stream"]
|
||||
else:
|
||||
self.stream = cuda.CUstream(0)
|
||||
|
||||
# Get host problem size
|
||||
self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
import ctypes
|
||||
|
||||
|
||||
@ -657,7 +657,10 @@ class _ArchListSetter:
|
||||
"""
|
||||
Restores the old value of TORCH_CUDA_ARCH_LIST
|
||||
"""
|
||||
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
|
||||
if self.old_arch_list is None:
|
||||
del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST]
|
||||
else:
|
||||
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
|
||||
|
||||
|
||||
def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
|
||||
|
||||
@ -112,6 +112,7 @@
|
||||
args.sync()
|
||||
"""
|
||||
|
||||
from cuda import cuda
|
||||
from cutlass_library import (
|
||||
ConvKind,
|
||||
ConvMode,
|
||||
@ -131,7 +132,6 @@ from cutlass.backend.library import TensorDescription, TileDescription
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.shape import Conv2DProblemSize, MatrixCoord
|
||||
from cutlass.utils import check, datatypes
|
||||
from cuda import cuda
|
||||
|
||||
|
||||
class Conv2d(OperationBase):
|
||||
|
||||
@ -116,6 +116,7 @@
|
||||
|
||||
from math import prod
|
||||
|
||||
from cuda import cuda
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
DataTypeSize,
|
||||
@ -131,7 +132,6 @@ from cutlass.backend.library import TensorDescription, TileDescription
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass.utils import check, datatypes
|
||||
from cuda import cuda
|
||||
|
||||
|
||||
class Gemm(OperationBase):
|
||||
@ -691,6 +691,7 @@ class Gemm(OperationBase):
|
||||
'D': self._get_batch_stride(D)
|
||||
}
|
||||
}
|
||||
|
||||
kwargs['stream'] = stream
|
||||
|
||||
if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
|
||||
|
||||
@ -53,6 +53,7 @@
|
||||
|
||||
from cutlass_library import DataTypeSize
|
||||
|
||||
from cuda import cuda
|
||||
from cutlass.backend.gemm_operation import (
|
||||
GemmGroupedArguments,
|
||||
GemmOperationGrouped,
|
||||
@ -65,7 +66,6 @@ from cutlass.backend.library import (
|
||||
from cutlass.op.gemm import Gemm
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass.utils import check, datatypes
|
||||
from cuda import cuda
|
||||
|
||||
|
||||
class GroupedGemm(Gemm):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
||||
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for expressing shapes
|
||||
|
||||
@ -1,3 +1,35 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS'
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS'
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS'
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
|
||||
Reference in New Issue
Block a user