CUTLASS 3.0.0 (#786)

* CUTLASS 3.0.0
This commit is contained in:
Vijay Thakkar
2023-01-23 17:55:28 -08:00
committed by GitHub
parent 66d9cddc83
commit 277bd6e537
377 changed files with 76396 additions and 1186 deletions

153
include/cute/util/debug.hpp Normal file
View File

@ -0,0 +1,153 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/**
* \file
* \brief Debugging and logging functionality
*/
#include <cuda_runtime_api.h>
#include <cute/config.hpp>
namespace cute
{
/******************************************************************************
* Debug and logging macros
******************************************************************************/
/**
* Formats and prints the given message to stdout
*/
#if !defined(CUTE_LOG)
# if !defined(__CUDA_ARCH__)
# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__)
# else
# define CUTE_LOG(format, ...) \
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
blockIdx.x, blockIdx.y, blockIdx.z, \
threadIdx.x, threadIdx.y, threadIdx.z, \
__VA_ARGS__);
# endif
#endif
/**
* Formats and prints the given message to stdout only if DEBUG is defined
*/
#if !defined(CUTE_LOG_DEBUG)
# ifdef DEBUG
# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__)
# else
# define CUTE_LOG_DEBUG(format, ...)
# endif
#endif
/**
* \brief Perror macro with exit
*/
#if !defined(CUTE_ERROR_EXIT)
# define CUTE_ERROR_EXIT(e) \
do { \
cudaError_t code = (e); \
if (code != cudaSuccess) { \
fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \
__FILE__, __LINE__, #e, \
cudaGetErrorName(code), cudaGetErrorString(code)); \
fflush(stderr); \
exit(0); \
} \
} while (0)
#endif
#if !defined(CUTE_CHECK_LAST)
# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize())
#endif
#if !defined(CUTE_CHECK_ERROR)
# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e)
#endif
// A dummy function that uses compilation failure to print a type
template <class T>
CUTE_HOST_DEVICE
void
print_type(T&&) {
static_assert(sizeof(T) < 0, "Printing type T.");
}
//
// Device-specific helpers
//
// e.g.
// if (thread0()) print(...);
// if (block0()) print(...);
// if (thread(42)) print(...);
CUTE_HOST_DEVICE
bool
thread(int tid, int bid)
{
#if defined(__CUDA_ARCH__)
return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid)
&& ( blockIdx.x + blockIdx.y* gridDim.x + blockIdx.z* gridDim.x* gridDim.y == bid);
#else
return true;
#endif
}
CUTE_HOST_DEVICE
bool
thread(int tid)
{
return thread(tid, 0);
}
CUTE_HOST_DEVICE
bool
thread0()
{
return thread(0,0);
}
CUTE_HOST_DEVICE
bool
block0()
{
#if defined(__CUDA_ARCH__)
return !(blockIdx.x | blockIdx.y | blockIdx.z);
#else
return true;
#endif
}
} // end namespace cute

140
include/cute/util/print.hpp Normal file
View File

@ -0,0 +1,140 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <type_traits>
#include <cute/config.hpp>
//
// CUDA compatible print and printf
//
namespace cute
{
CUTE_HOST_DEVICE
int
num_digits(int x)
{
return (x < 10 ? 1 :
(x < 100 ? 2 :
(x < 1000 ? 3 :
(x < 10000 ? 4 :
(x < 100000 ? 5 :
(x < 1000000 ? 6 :
(x < 10000000 ? 7 :
(x < 100000000 ? 8 :
(x < 1000000000 ? 9 :
10)))))))));
}
template <class T>
struct format_and_size {
using type = T;
char const* format;
int digits;
};
CUTE_HOST_DEVICE
format_and_size<int>
get_format(bool) {
return {"%*d", 3};
}
CUTE_HOST_DEVICE
format_and_size<int32_t>
get_format(int32_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<uint32_t>
get_format(uint32_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<int64_t>
get_format(int64_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<uint64_t>
get_format(uint64_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<float>
get_format(half_t) {
return {"%*.2f", 8};
}
CUTE_HOST_DEVICE
format_and_size<float>
get_format(float) {
return {"%*.2e", 10};
}
CUTE_HOST_DEVICE
format_and_size<double>
get_format(double) {
return {"%*.3e", 11};
}
//
// print dispatcher
//
CUTE_HOST_DEVICE
void
print(char const& c) {
printf("%c", c);
}
template <class T,
__CUTE_REQUIRES(std::is_integral<T>::value)>
CUTE_HOST_DEVICE
void
print(T const& a) {
printf("%d", int(a));
}
template <class... T>
CUTE_HOST_DEVICE
void
print(char const* format, T const&... t) {
printf(format, t...);
}
} // end namespace cute

View File

@ -0,0 +1,101 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <type_traits>
#include <cute/config.hpp>
#define __CUTE_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr
#define __CUTE_REQUIRES_V(...) typename std::enable_if<decltype((__VA_ARGS__))::value>::type* = nullptr
namespace cute
{
using std::conjunction;
using std::conjunction_v;
using std::disjunction;
using std::disjunction_v;
using std::negation;
using std::negation_v;
using std::void_t;
// C++20
// using std::remove_cvref;
template <class T>
struct remove_cvref {
using type = std::remove_cv_t<std::remove_reference_t<T>>;
};
// C++20
// using std::remove_cvref_t;
template <class T>
using remove_cvref_t = typename remove_cvref<T>::type;
//
// is_valid
//
namespace detail {
template <class F, class... Args, class = decltype(std::declval<F&&>()(std::declval<Args&&>()...))>
CUTE_HOST_DEVICE constexpr auto
is_valid_impl(int) { return std::true_type{}; }
template <class F, class... Args>
CUTE_HOST_DEVICE constexpr auto
is_valid_impl(...) { return std::false_type{}; }
template <class F>
struct is_valid_fn {
template <class... Args>
CUTE_HOST_DEVICE constexpr auto
operator()(Args&&...) const { return is_valid_impl<F, Args&&...>(int{}); }
};
} // end namespace detail
template <class F>
CUTE_HOST_DEVICE constexpr auto
is_valid(F&&) {
return detail::is_valid_fn<F&&>{};
}
template <class F, class... Args>
CUTE_HOST_DEVICE constexpr auto
is_valid(F&&, Args&&...) {
return detail::is_valid_impl<F&&, Args&&...>(int{});
}
} // end namespace cute