CUTLASS 3.0.0 (#786)

* CUTLASS 3.0.0
Vijay Thakkar
2023-01-23 17:55:28 -08:00
committed by GitHub
parent 66d9cddc83
commit 277bd6e537
377 changed files with 76396 additions and 1186 deletions

View File

@@ -0,0 +1,67 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cuda_runtime.h>
// Simple CUDA-event-based timer: construction (or start()) records a start
// event; milliseconds()/seconds() record a stop event, synchronize on it,
// and return the time elapsed since the last start.
struct GPU_Clock
{
GPU_Clock() {
cudaEventCreate(&start_);
cudaEventCreate(&stop_);
cudaEventRecord(start_);
}
~GPU_Clock() {
cudaEventDestroy(start_);
cudaEventDestroy(stop_);
}
void start() {
cudaEventRecord(start_);
}
float milliseconds() {
cudaEventRecord(stop_);
cudaEventSynchronize(stop_);
float time;
cudaEventElapsedTime(&time, start_, stop_);
return time;
}
float seconds() {
return milliseconds() * float(1e-3);
}
private:
cudaEvent_t start_, stop_;
};
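A minimal usage sketch (not part of this commit), assuming the header above is included; dummy_kernel is a hypothetical placeholder that exists only to give the timer something to measure:
#include <cstdio>
__global__ void dummy_kernel() {}
int main() {
  GPU_Clock timer;                   // construction records the start event
  dummy_kernel<<<1, 1>>>();
  float ms = timer.milliseconds();   // records the stop event and synchronizes
  timer.start();                     // re-arm the timer for a second region
  dummy_kernel<<<1, 1>>>();
  std::printf("first: %f ms, second: %f s\n", ms, timer.seconds());
  return 0;
}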

View File

@@ -0,0 +1,526 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cuda_runtime.h>
#include <cublas_v2.h>
//-- BLAM_DEBUG_OUT ---------------------------------------------------------
#ifdef BLAM_DEBUG
# include <iostream>
# ifndef BLAM_DEBUG_OUT
# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl
# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl
# endif // BLAM_DEBUG_OUT
#else
# ifndef BLAM_DEBUG_OUT
# define BLAM_DEBUG_OUT(msg)
# define BLAM_DEBUG_OUT_2(msg)
# endif // BLAM_DEBUG_OUT
#endif // BLAM_DEBUG
// Users could define ComplexFloat/ComplexDouble themselves instead of the cuda::std::complex defaults below
#ifndef BLAM_COMPLEX_TYPES
#define BLAM_COMPLEX_TYPES 1
#include <cuda/std/complex>
namespace blam {
template <typename T>
using Complex = cuda::std::complex<T>;
using ComplexFloat = cuda::std::complex<float>;
using ComplexDouble = cuda::std::complex<double>;
}
#endif // BLAM_COMPLEX_TYPES
// Users could define Half themselves instead of the cute::half_t default below
#ifndef BLAM_HALF_TYPE
#define BLAM_HALF_TYPE 1
#include <cute/numeric/half.hpp>
namespace blam {
using Half = cute::half_t;
}
#endif // BLAM_HALF_TYPE
namespace blam
{
namespace cublas
{
inline const char*
cublas_get_error(cublasStatus_t status)
{
switch (status) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized.";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library.";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function.";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture.";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed.";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute.";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed.";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported.";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing.";
default:
return "CUBLAS_ERROR -- <unknown>";
}
}
inline bool
cublas_is_error(cublasStatus_t status)
{
return status != CUBLAS_STATUS_SUCCESS;
}
// hgemm
inline cublasStatus_t
gemm(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const Half* alpha,
const Half* A, int ldA,
const Half* B, int ldB,
const Half* beta,
Half* C, int ldC)
{
BLAM_DEBUG_OUT("cublasHgemm");
return cublasGemmEx(handle, transA, transB,
m, n, k,
reinterpret_cast<const __half*>(alpha),
reinterpret_cast<const __half*>(A), CUDA_R_16F, ldA,
reinterpret_cast<const __half*>(B), CUDA_R_16F, ldB,
reinterpret_cast<const __half*>(beta),
reinterpret_cast< __half*>(C), CUDA_R_16F, ldC,
CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// mixed hf gemm
inline cublasStatus_t
gemm(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const float* alpha,
const Half* A, int ldA,
const Half* B, int ldB,
const float* beta,
float* C, int ldC)
{
BLAM_DEBUG_OUT("cublasGemmEx mixed half-float");
return cublasGemmEx(handle, transA, transB,
m, n, k,
alpha,
reinterpret_cast<const __half*>(A), CUDA_R_16F, ldA,
reinterpret_cast<const __half*>(B), CUDA_R_16F, ldB,
beta,
C, CUDA_R_32F, ldC,
CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// igemm
inline cublasStatus_t
gemm(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const int32_t* alpha,
const int8_t* A, int ldA,
const int8_t* B, int ldB,
const int32_t* beta,
int32_t* C, int ldC)
{
BLAM_DEBUG_OUT("cublasIgemm");
return cublasGemmEx(handle, transA, transB,
m, n, k,
alpha,
A, CUDA_R_8I, ldA,
B, CUDA_R_8I, ldB,
beta,
C, CUDA_R_32I, ldC,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// sgemm
inline cublasStatus_t
gemm(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const float* alpha,
const float* A, int ldA,
const float* B, int ldB,
const float* beta,
float* C, int ldC)
{
BLAM_DEBUG_OUT("cublasSgemm");
return cublasSgemm(handle, transA, transB,
m, n, k,
alpha,
A, ldA,
B, ldB,
beta,
C, ldC);
}
// dgemm
inline cublasStatus_t
gemm(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const double* alpha,
const double* A, int ldA,
const double* B, int ldB,
const double* beta,
double* C, int ldC)
{
BLAM_DEBUG_OUT("cublasDgemm");
return cublasDgemm(handle, transA, transB,
m, n, k,
alpha,
A, ldA,
B, ldB,
beta,
C, ldC);
}
// cgemm
inline cublasStatus_t
gemm(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const ComplexFloat* alpha,
const ComplexFloat* A, int ldA,
const ComplexFloat* B, int ldB,
const ComplexFloat* beta,
ComplexFloat* C, int ldC)
{
BLAM_DEBUG_OUT("cublasCgemm");
return cublasCgemm(handle, transA, transB,
m, n, k,
reinterpret_cast<const cuFloatComplex*>(alpha),
reinterpret_cast<const cuFloatComplex*>(A), ldA,
reinterpret_cast<const cuFloatComplex*>(B), ldB,
reinterpret_cast<const cuFloatComplex*>(beta),
reinterpret_cast<cuFloatComplex*>(C), ldC);
}
// zgemm
inline cublasStatus_t
gemm(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const ComplexDouble* alpha,
const ComplexDouble* A, int ldA,
const ComplexDouble* B, int ldB,
const ComplexDouble* beta,
ComplexDouble* C, int ldC)
{
BLAM_DEBUG_OUT("cublasZgemm");
return cublasZgemm(handle, transA, transB,
m, n, k,
reinterpret_cast<const cuDoubleComplex*>(alpha),
reinterpret_cast<const cuDoubleComplex*>(A), ldA,
reinterpret_cast<const cuDoubleComplex*>(B), ldB,
reinterpret_cast<const cuDoubleComplex*>(beta),
reinterpret_cast<cuDoubleComplex*>(C), ldC);
}
// hgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const Half* alpha,
const Half* A, int ldA, int loA,
const Half* B, int ldB, int loB,
const Half* beta,
Half* C, int ldC, int loC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasHgemmStridedBatched");
return cublasHgemmStridedBatched(handle, transA, transB,
m, n, k,
reinterpret_cast<const __half*>(alpha),
reinterpret_cast<const __half*>(A), ldA, loA,
reinterpret_cast<const __half*>(B), ldB, loB,
reinterpret_cast<const __half*>(beta),
reinterpret_cast<__half*>(C), ldC, loC,
batch_size);
}
// sgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const float* alpha,
const float* A, int ldA, int loA,
const float* B, int ldB, int loB,
const float* beta,
float* C, int ldC, int loC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasSgemmStridedBatched");
return cublasSgemmStridedBatched(handle, transA, transB,
m, n, k,
alpha,
A, ldA, loA,
B, ldB, loB,
beta,
C, ldC, loC,
batch_size);
}
// dgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const double* alpha,
const double* A, int ldA, int loA,
const double* B, int ldB, int loB,
const double* beta,
double* C, int ldC, int loC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasDgemmStridedBatched");
return cublasDgemmStridedBatched(handle, transA, transB,
m, n, k,
alpha,
A, ldA, loA,
B, ldB, loB,
beta,
C, ldC, loC,
batch_size);
}
// cgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const ComplexFloat* alpha,
const ComplexFloat* A, int ldA, int loA,
const ComplexFloat* B, int ldB, int loB,
const ComplexFloat* beta,
ComplexFloat* C, int ldC, int loC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasCgemmStridedBatched");
return cublasCgemmStridedBatched(handle, transA, transB,
m, n, k,
reinterpret_cast<const cuFloatComplex*>(alpha),
reinterpret_cast<const cuFloatComplex*>(A), ldA, loA,
reinterpret_cast<const cuFloatComplex*>(B), ldB, loB,
reinterpret_cast<const cuFloatComplex*>(beta),
reinterpret_cast<cuFloatComplex*>(C), ldC, loC,
batch_size);
}
// zgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const ComplexDouble* alpha,
const ComplexDouble* A, int ldA, int loA,
const ComplexDouble* B, int ldB, int loB,
const ComplexDouble* beta,
ComplexDouble* C, int ldC, int loC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasZgemmStridedBatched");
return cublasZgemmStridedBatched(handle, transA, transB,
m, n, k,
reinterpret_cast<const cuDoubleComplex*>(alpha),
reinterpret_cast<const cuDoubleComplex*>(A), ldA, loA,
reinterpret_cast<const cuDoubleComplex*>(B), ldB, loB,
reinterpret_cast<const cuDoubleComplex*>(beta),
reinterpret_cast<cuDoubleComplex*>(C), ldC, loC,
batch_size);
}
// hgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const Half* alpha,
const Half* const A[], int ldA,
const Half* const B[], int ldB,
const Half* beta,
Half* const C[], int ldC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasHgemmBatched");
return cublasHgemmBatched(handle, transA, transB,
m, n, k,
reinterpret_cast<const __half*>(alpha),
reinterpret_cast<const __half**>(const_cast<const Half**>(A)), ldA,
// A, ldA, // cuBLAS 9.2
reinterpret_cast<const __half**>(const_cast<const Half**>(B)), ldB,
// B, ldB, // cuBLAS 9.2
reinterpret_cast<const __half*>(beta),
reinterpret_cast<__half**>(const_cast<Half**>(C)), ldC,
// C, ldC, // cuBLAS 9.2
batch_size);
}
// sgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const float* alpha,
const float* const A[], int ldA,
const float* const B[], int ldB,
const float* beta,
float* const C[], int ldC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasSgemmBatched");
return cublasSgemmBatched(handle, transA, transB,
m, n, k,
alpha,
const_cast<const float**>(A), ldA,
// A, ldA, // cuBLAS 9.2
const_cast<const float**>(B), ldB,
// B, ldB, // cuBLAS 9.2
beta,
const_cast<float**>(C), ldC,
// C, ldC, // cuBLAS 9.2
batch_size);
}
// dgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const double* alpha,
const double* const A[], int ldA,
const double* const B[], int ldB,
const double* beta,
double* const C[], int ldC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasDgemmBatched");
return cublasDgemmBatched(handle, transA, transB,
m, n, k,
alpha,
const_cast<const double**>(A), ldA,
// A, ldA, // cuBLAS 9.2
const_cast<const double**>(B), ldB,
// B, ldB, // cuBLAS 9.2
beta,
const_cast<double**>(C), ldC,
// C, ldC, // cuBLAS 9.2
batch_size);
}
// cgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const ComplexFloat* alpha,
const ComplexFloat* const A[], int ldA,
const ComplexFloat* const B[], int ldB,
const ComplexFloat* beta,
ComplexFloat* const C[], int ldC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasCgemmBatched");
return cublasCgemmBatched(handle, transA, transB,
m, n, k,
reinterpret_cast<const cuFloatComplex*>(alpha),
const_cast<const cuFloatComplex**>(reinterpret_cast<const cuFloatComplex* const *>(A)), ldA,
//reinterpret_cast<const cuFloatComplex* const *>(A), ldA, // cuBLAS 9.2
const_cast<const cuFloatComplex**>(reinterpret_cast<const cuFloatComplex* const *>(B)), ldB,
//reinterpret_cast<const cuFloatComplex* const *>(B), ldB, // cuBLAS 9.2
reinterpret_cast<const cuFloatComplex*>(beta),
const_cast<cuFloatComplex**>(reinterpret_cast<cuFloatComplex* const *>(C)), ldC,
//reinterpret_cast<cuFloatComplex* const *>(C), ldC, // cuBLAS 9.2
batch_size);
}
// zgemm
inline cublasStatus_t
gemm_batch(cublasHandle_t handle,
cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k,
const ComplexDouble* alpha,
const ComplexDouble* const A[], int ldA,
const ComplexDouble* const B[], int ldB,
const ComplexDouble* beta,
ComplexDouble* const C[], int ldC,
int batch_size)
{
BLAM_DEBUG_OUT("cublasZgemmBatched");
return cublasZgemmBatched(handle, transA, transB,
m, n, k,
reinterpret_cast<const cuDoubleComplex*>(alpha),
const_cast<const cuDoubleComplex**>(reinterpret_cast<const cuDoubleComplex* const *>(A)), ldA,
//reinterpret_cast<const cuDoubleComplex* const *>(A), ldA, // cuBLAS 9.2
const_cast<const cuDoubleComplex**>(reinterpret_cast<const cuDoubleComplex* const *>(B)), ldB,
//reinterpret_cast<const cuDoubleComplex* const *>(B), ldB, // cuBLAS 9.2
reinterpret_cast<const cuDoubleComplex*>(beta),
const_cast<cuDoubleComplex**>(reinterpret_cast<cuDoubleComplex* const *>(C)), ldC,
//reinterpret_cast<cuDoubleComplex* const *>(C), ldC, // cuBLAS 9.2
batch_size);
}
} // end namespace cublas
} // end namespace blam
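A hedged usage sketch (not part of this commit): with float arguments, overload resolution selects the cublasSgemm wrapper. The matrix sizes and leading dimensions are illustrative, and the device buffers are left uninitialized for brevity.
#include <cstdio>
int main() {
  int m = 4, n = 4, k = 4;
  float *d_A, *d_B, *d_C;
  cudaMalloc(&d_A, sizeof(float) * m * k);
  cudaMalloc(&d_B, sizeof(float) * k * n);
  cudaMalloc(&d_C, sizeof(float) * m * n);
  cublasHandle_t handle;
  cublasCreate(&handle);
  float alpha = 1.0f, beta = 0.0f;
  // Column-major m x n = (m x k) * (k x n)
  cublasStatus_t status =
      blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                         m, n, k,
                         &alpha, d_A, m, d_B, k, &beta, d_C, m);
  if (blam::cublas::cublas_is_error(status)) {
    std::fprintf(stderr, "%s\n", blam::cublas::cublas_get_error(status));
  }
  cublasDestroy(handle);
  cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);
  return 0;
}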

View File

@@ -456,7 +456,7 @@ void layernorm(cutlass::MatrixCoord tensor_size,
block.x = 1024;
}
// TODO : There should be better configs for different cases, we only use several samples to show how to use here
// TODO : using registers to store values locally can reduce the ldgs from global memory and speedup the kernels.
// TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels.
if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) {
block.x = (n/4 + 31)/32*32;
if (std::is_same<T, float>::value) {

View File

@@ -0,0 +1,116 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cute/util/debug.hpp>
namespace cute
{
inline void
device_init(int device_id, bool quiet = false)
{
cudaDeviceProp device_prop;
std::size_t device_free_physmem;
std::size_t device_total_physmem;
CUTE_CHECK_ERROR(cudaSetDevice(device_id));
CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem));
CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id));
if (device_prop.major < 1) {
fprintf(stderr, "Device does not support CUDA.\n");
exit(1);
}
//float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000;
if (!quiet) {
printf("Using device %d: %s (SM%d, %d SMs)\n",
device_id, device_prop.name,
device_prop.major * 10 + device_prop.minor,
device_prop.multiProcessorCount);
fflush(stdout);
}
}
/**
* Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores.
*/
inline int
_ConvertSMVer2Cores(int major, int minor)
{
// Defines for GPU Architecture types (using the SM version to determine
// the # of cores per SM)
typedef struct {
int SM; // 0xMm (hexadecimal notation), M = SM Major version,
// and m = SM minor version
int Cores;
} sSMtoCores;
sSMtoCores nGpuArchCoresPerSM[] = {
{0x30, 192},
{0x32, 192},
{0x35, 192},
{0x37, 192},
{0x50, 128},
{0x52, 128},
{0x53, 128},
{0x60, 64},
{0x61, 128},
{0x62, 128},
{0x70, 64},
{0x72, 64},
{0x75, 64},
{-1, -1}};
int index = 0;
while (nGpuArchCoresPerSM[index].SM != -1) {
if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) {
return nGpuArchCoresPerSM[index].Cores;
}
index++;
}
// If we don't find the value, default to the last known entry
// so the program can still run.
printf("MapSMtoCores for SM %d.%d is undefined."
" Default to use %d Cores/SM\n",
major, minor, nGpuArchCoresPerSM[index - 1].Cores);
return nGpuArchCoresPerSM[index - 1].Cores;
}
} // end namespace cute
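A hedged usage sketch (not part of this commit): initialize device 0 and estimate its total CUDA core count from the reported SM version.
#include <cstdio>
int main() {
  cute::device_init(0);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  int cores_per_sm = cute::_ConvertSMVer2Cores(prop.major, prop.minor);
  std::printf("~%d CUDA cores total\n", cores_per_sm * prop.multiProcessorCount);
  return 0;
}
Note that the lookup table above stops at SM 7.5; for newer architectures the fallback path prints a warning and reuses the last entry.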

View File

@@ -0,0 +1,101 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Utilities for packing a rank-X shape into a rank-(X-1) stride in CuTe.
*/
#pragma once
#include "cute/stride.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
// Strides without batch mode
template <class StrideIntT>
cute::Stride<StrideIntT, cute::Int<1>>
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
auto s_copy = s;
cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
return s_copy;
}
template <class StrideIntT>
cute::Stride<cute::Int<1>, StrideIntT>
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
auto s_copy = s;
cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
return s_copy;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Strides with batch mode
template <class StrideIntT>
cute::Stride<StrideIntT, cute::Int<1>, int64_t>
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
auto s_copy = s;
cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
int batch_count = cute::get<2>(shape_MKL);
if (batch_count > 1) {
cute::get<2>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
}
else {
cute::get<2>(s_copy) = static_cast<StrideIntT>(0);
}
return s_copy;
}
template <class StrideIntT>
cute::Stride<cute::Int<1>, StrideIntT, int64_t>
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
auto s_copy = s;
cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
int batch_count = cute::get<2>(shape_MKL);
if (batch_count > 1) {
cute::get<2>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
}
else {
cute::get<2>(s_copy) = static_cast<StrideIntT>(0);
}
return s_copy;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
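A hedged usage sketch (not part of this commit): packing the strides of a K-major (unit stride along K) M x K x L tensor; the extents are illustrative.
#include <cstdio>
#include "cute/tensor.hpp"
int main() {
  int M = 128, K = 64, L = 2;
  auto stride_A = make_cute_packed_stride(
      cute::Stride<int64_t, cute::Int<1>, int64_t>{}, cute::make_shape(M, K, L));
  // Expect (64, _1, 8192): leading stride K, unit stride along K, batch stride M*K.
  std::printf("%lld %lld\n",
              (long long) cute::get<0>(stride_A),
              (long long) cute::get<2>(stride_A));
  return 0;
}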

View File

@@ -0,0 +1,235 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <array>
#include <cassert>
#include <cmath>
#include <iostream>
#include <type_traits>
#include <cute/util/type_traits.hpp>
#include <cute/tensor.hpp>
#include <cute/numeric/half.hpp>
#include <cute/numeric/complex.hpp>
#include <cutlass/layout/layout.h>
// The computed infinity norm does not include
// any NaN row absolute-value sums.
struct matrix_inf_norm_result {
// Accumulate errors in double, as this is generally
// the highest precision that the examples use.
double inf_norm = 0.0;
bool found_nan = false;
};
// In theory, a cute::Tensor with a ViewEngine<T*> engine could be treated as a
// view type, and thus passed by value (as std::span or std::string_view would be).
// However, generic cute::Tensors behave more like containers
// and thus are best passed by reference or const reference.
template <typename EngineType, typename LayoutType>
matrix_inf_norm_result
matrix_inf_norm(const cute::Tensor<EngineType, LayoutType>& host_matrix)
{
using std::abs;
using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
error_type inf_norm = 0.0;
bool found_nan = false;
const auto shape = host_matrix.shape();
using index_type = std::decay_t<decltype(cute::get<0>(shape))>;
// Computing the infinity norm requires that we be able
// to treat the input as a matrix, with rows and columns.
static_assert(std::is_integral_v<index_type>);
const index_type num_rows = cute::get<0>(shape);
const index_type num_cols = cute::get<1>(shape);
for(index_type i = 0; i < num_rows; ++i) {
error_type row_abs_sum = 0.0;
for(index_type j = 0; j < num_cols; ++j) {
row_abs_sum += abs(host_matrix(i, j));
}
if(std::isnan(row_abs_sum)) {
found_nan = true;
} else {
inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
}
}
return {inf_norm, found_nan};
}
// Infinity norm of (X - Y).
template <typename EngineType, typename LayoutType>
matrix_inf_norm_result
matrix_diff_inf_norm(const cute::Tensor<EngineType, LayoutType>& X,
const cute::Tensor<EngineType, LayoutType>& Y)
{
using std::abs;
using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
const auto X_shape = X.shape();
const auto Y_shape = Y.shape();
using index_type = std::decay_t<decltype(cute::get<0>(X_shape))>;
// Computing the infinity norm requires that we be able
// to treat the input as a matrix, with rows and columns.
static_assert(std::is_integral_v<index_type>);
const index_type num_rows = cute::get<0>(X_shape);
const index_type num_cols = cute::get<1>(X_shape);
assert(num_rows == cute::get<0>(Y_shape));
assert(num_cols == cute::get<1>(Y_shape));
auto matrix_ij = [&](const auto& A, std::size_t i, std::size_t j) {
return A(i, j);
};
auto diff_ij = [&](std::size_t i, std::size_t j) {
return matrix_ij(X, i, j) - matrix_ij(Y, i, j);
};
error_type inf_norm = 0.0;
bool found_nan = false;
for(index_type i = 0; i < num_rows; ++i) {
error_type row_abs_sum = 0.0;
for(index_type j = 0; j < num_cols; ++j) {
row_abs_sum += abs(diff_ij(i, j));
}
if(std::isnan(row_abs_sum)) {
found_nan = true;
} else {
inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
}
}
return {inf_norm, found_nan};
}
template <typename EngineType_A, typename LayoutType_A,
typename EngineType_B, typename LayoutType_B,
typename EngineType_C_computed, typename LayoutType_C_computed,
typename EngineType_C_expected, typename LayoutType_C_expected>
void
print_matrix_multiply_mollified_relative_error(
const char A_value_type_name[],
const cute::Tensor<EngineType_A, LayoutType_A>& A,
const char B_value_type_name[],
const cute::Tensor<EngineType_B, LayoutType_B>& B,
const char C_value_type_name[],
const cute::Tensor<EngineType_C_computed, LayoutType_C_computed>& C_computed,
const cute::Tensor<EngineType_C_expected, LayoutType_C_expected>& C_expected)
{
const auto [A_norm, A_has_nan] = matrix_inf_norm(A);
const auto [B_norm, B_has_nan] = matrix_inf_norm(B);
const auto [C_norm, C_has_nan] = matrix_inf_norm(C_expected);
const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C_computed, C_expected);
const auto A_norm_times_B_norm = A_norm * B_norm;
const auto relative_error = A_norm_times_B_norm == 0.0 ?
diff_norm : (diff_norm / A_norm_times_B_norm);
// For expected error bounds, please refer to the LAPACK Users' Guide,
// in particular https://netlib.org/lapack/lug/node108.html .
// Printing the infinity norm of C is a way to check
// that both the function being tested (C_computed)
// and the reference implementation (C_expected)
// don't just do nothing (or fill with zeros).
using std::cout;
cout << "Value type of A: " << A_value_type_name << '\n'
<< std::scientific
<< "Infinity norm of A: " << A_norm << '\n'
<< "Value type of B: " << B_value_type_name << '\n'
<< "Infinity norm of B: " << B_norm << '\n'
<< "Value type of C: " << C_value_type_name << '\n'
<< "Infinity norm of C_expected: " << C_norm << '\n'
<< "Infinity norm of (C_computed - C_expected): " << diff_norm << '\n';
if(A_norm_times_B_norm == 0.0) {
cout << "Mollified relative error: " << relative_error << '\n';
} else {
cout << "Relative error: " << relative_error << '\n';
}
cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
<< "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
<< "Did we encounter NaN in C_expected? " << (C_has_nan ? "yes" : "no") << '\n'
<< "Did we encounter NaN in (C_computed - C_expected)? "
<< (diff_has_nan ? "yes" : "no") << '\n';
}
template <typename EngineType, typename LayoutType>
void
print_matrix_multiply_mollified_relative_error(
const char value_type_name[],
const cute::Tensor<EngineType, LayoutType>& A,
const cute::Tensor<EngineType, LayoutType>& B,
const cute::Tensor<EngineType, LayoutType>& C_computed,
const cute::Tensor<EngineType, LayoutType>& C_expected)
{
print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
value_type_name, C_computed, C_expected);
}
// Take a CUTLASS HostTensor (or the like) as input,
// and return a const CuTe Tensor.
// This is useful for use with the above error printing functions.
// This implicitly "transposes" if the layout is RowMajor.
// Note that the HostTensor must be captured by nonconst reference
// in order for X.host_ref().data() to compile.
// (CUTLASS is a bit more container-y than CuTe.)
template<class CutlassHostTensorType>
auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X)
{
// The tensors were created with post-transposed extents.
const auto extents = X.extent();
const auto shape = cute::Shape<int, int>{extents[0], extents[1]};
// Both RowMajor and ColumnMajor only store one stride.
const int LDX = X.stride(0);
const auto strides = [&]() {
using input_layout_type = typename std::decay_t<decltype(X)>::Layout;
if constexpr (std::is_same_v<input_layout_type, cutlass::layout::ColumnMajor>) {
return cute::Stride<int, int>{1, LDX};
}
else {
static_assert(std::is_same_v<input_layout_type, cutlass::layout::RowMajor>);
return cute::Stride<int, int>{LDX, 1};
}
}();
const auto layout = cute::make_layout(shape, strides);
auto X_data = X.host_ref().data();
  auto X_data_const = const_cast<std::remove_pointer_t<decltype(X_data)> const*>(X_data);
  return cute::make_tensor(X_data_const, layout);
}
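A hedged usage sketch (not part of this commit), assuming four column-major CUTLASS HostTensors of matching extents that were filled and computed elsewhere; report_gemm_error is an illustrative helper name.
#include "cutlass/util/host_tensor.h"
void report_gemm_error(cutlass::HostTensor<float, cutlass::layout::ColumnMajor>& A,
                       cutlass::HostTensor<float, cutlass::layout::ColumnMajor>& B,
                       cutlass::HostTensor<float, cutlass::layout::ColumnMajor>& C_computed,
                       cutlass::HostTensor<float, cutlass::layout::ColumnMajor>& C_expected)
{
  auto cute_A     = host_matrix_to_const_cute_tensor(A);
  auto cute_B     = host_matrix_to_const_cute_tensor(B);
  auto cute_C     = host_matrix_to_const_cute_tensor(C_computed);
  auto cute_C_ref = host_matrix_to_const_cute_tensor(C_expected);
  print_matrix_multiply_mollified_relative_error("float", cute_A, cute_B, cute_C, cute_C_ref);
}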

View File

@@ -0,0 +1,311 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for GETT in host-side code.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/complex.h"
#include "cutlass/numeric_conversion.h"
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::reference::host {
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementAccumulator_,
class TensorA_, // (M, K, L)
class TensorB_ // (N, K, L)
>
struct GettMainloopParams {
using ElementAccumulator = ElementAccumulator_;
using TensorA = TensorA_;
using TensorB = TensorB_;
using EngineA = typename TensorA::engine_type;
using LayoutA = typename TensorA::layout_type;
using EngineB = typename TensorB::engine_type;
using LayoutB = typename TensorB::layout_type;
TensorA A{};
TensorB B{};
ComplexTransform transform_A = ComplexTransform::kNone;
ComplexTransform transform_B = ComplexTransform::kNone;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementAccumulator_,
class ElementCompute_,
class TensorC_, // (M, N, L)
class TensorD_ // (M, N, L)
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using TensorC = TensorC_;
using TensorD = TensorD_;
using EngineC = typename TensorC::engine_type;
using LayoutC = typename TensorC::layout_type;
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
TensorC C{};
TensorD D{};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - General Tensor-Tensor contraction reference kernel
template <
class MainloopParams,
class EpilogueParams
>
void Gett(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
static int constexpr kBlockM = 64;
static int constexpr kBlockN = 64;
#pragma omp parallel for collapse(3)
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
gett_mainloop(mainloop_params, m, n, l, acc);
gett_epilogue(epilogue_params, m, n, l, acc);
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Mainloop
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_mainloop(
MainloopParams const& mainloop_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, L");
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, L");
using ElementA = typename MainloopParams::EngineA::value_type;
using ElementB = typename MainloopParams::EngineB::value_type;
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
RingOp fma_op;
// Zero out accumulators
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
// Load A
ElementAccumulator a_frag[kBlockM];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
a_frag[m_b] = static_cast<ElementAccumulator>(mainloop_params.A(m + m_b, k, l));
if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
a_frag[m_b] = conj(a_frag[m_b]);
}
} else {
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// Load B
ElementAccumulator b_frag[kBlockN];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
b_frag[n_b] = static_cast<ElementAccumulator>(mainloop_params.B(n + n_b, k, l));
if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
b_frag[n_b] = conj(b_frag[n_b]);
}
} else {
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// do compute
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]);
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Epilogue
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_epilogue(
EpilogueParams const& epilogue_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, N, L");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "M, N, L");
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::EngineC::value_type;
using ElementD = typename EpilogueParams::EngineD::value_type;
using ElementScalar = typename EpilogueParams::ElementScalar;
// Input related converter
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
NumericConverter<ElementCompute, ElementC> source_converter;
// Scale related converter
NumericConverter<ElementCompute, ElementScalar> scale_converter;
// Output related converter
NumericConverter<ElementD, ElementCompute> destination_converter;
// Epilogue operations
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
multiplies<ElementCompute> mul;
// Do conversion
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
for (int n_b = 0; n_b < kBlockN; ++n_b) {
for (int m_b = 0; m_b < kBlockM; ++m_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
ElementCompute output = epilogue_fma(converted_alpha, converted_acc, ElementCompute(0));
output = epilogue_fma(converted_beta, converted_src, output);
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output);
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM - General Matrix-Matrix contraction without conjugation options
template <
class MainloopParams,
class EpilogueParams
>
void Gemm3x(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
using namespace cute;
static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{}));
static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{}));
static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{}));
if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) {
// append a batch mode of size 1 if we do not have tensors that are rank 3
Layout layout_A = make_layout(
make_shape(get<0>(mainloop_params.A.shape()), get<1>(mainloop_params.A.shape()), Int<1>{}),
make_stride(get<0>(mainloop_params.A.stride()), get<1>(mainloop_params.A.stride()), int64_t(cosize(mainloop_params.A.layout()))));
Layout layout_B = make_layout(
make_shape(get<0>(mainloop_params.B.shape()), get<1>(mainloop_params.B.shape()), Int<1>{}),
make_stride(get<0>(mainloop_params.B.stride()), get<1>(mainloop_params.B.stride()), int64_t(cosize(mainloop_params.B.layout()))));
Layout layout_C = make_layout(
make_shape(get<0>(epilogue_params.C.shape()), get<1>(epilogue_params.C.shape()), Int<1>{}),
make_stride(get<0>(epilogue_params.C.stride()), get<1>(epilogue_params.C.stride()), int64_t(cosize(epilogue_params.C.layout()))));
Layout layout_D = make_layout(
make_shape(get<0>(epilogue_params.D.shape()), get<1>(epilogue_params.D.shape()), Int<1>{}),
make_stride(get<0>(epilogue_params.D.stride()), get<1>(epilogue_params.D.stride()), int64_t(cosize(epilogue_params.D.layout()))));
auto TensorA = make_tensor(mainloop_params.A.data(), layout_A);
auto TensorB = make_tensor(mainloop_params.B.data(), layout_B);
auto TensorC = make_tensor(epilogue_params.C.data(), layout_C);
auto TensorD = make_tensor(epilogue_params.D.data(), layout_D);
// Reconstruct mainloop params
GettMainloopParams<typename MainloopParams::ElementAccumulator,
decltype(TensorA),
decltype(TensorB)>
mainloop_params_converted{TensorA,
TensorB,
mainloop_params.transform_A,
mainloop_params.transform_B};
// Reconstruct epilogue params
GettEpilogueParams<typename EpilogueParams::ElementScalar,
typename EpilogueParams::ElementAccumulator,
typename EpilogueParams::ElementCompute,
decltype(TensorC),
decltype(TensorD)
>
epilogue_params_converted{epilogue_params.alpha,
epilogue_params.beta,
TensorC,
TensorD
};
Gett(mainloop_params_converted, epilogue_params_converted);
}
else {
// if we already have a batch mode, just pass it through
Gett(mainloop_params, epilogue_params);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // cutlass::reference::host
/////////////////////////////////////////////////////////////////////////////////////////////////
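A hedged usage sketch (not part of this commit): a small wrapper, reference_gemm, that verifies a batched GEMM on the host. A, B, C, D are assumed to be host-resident CuTe tensors of shape (M,K,L), (N,K,L), (M,N,L), and (M,N,L) respectively, created elsewhere.
template <class TA, class TB, class TC, class TD>
void reference_gemm(TA const& A, TB const& B, TC const& C, TD& D, float alpha, float beta)
{
  // Accumulate, compute, and scale in float for this sketch.
  cutlass::reference::host::GettMainloopParams<float, TA, TB> mainloop_params{A, B};
  cutlass::reference::host::GettEpilogueParams<float, float, float, TC, TD>
      epilogue_params{alpha, beta, C, D};
  cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
}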

View File

@@ -0,0 +1,101 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Provides functions for comparing CuTe host tensor views.
*/
#pragma once
// Standard Library includes
#include <utility>
#include <cstdlib>
#include <cmath>
// Cute includes
#include "cute/tensor.hpp"
// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/quaternion.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace reference {
namespace host {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns true if two tensor views are equal.
template <
typename TensorL,
typename TensorR
>
bool TensorEquals(
TensorL lhs,
TensorR rhs) {
// Extents must be identical
if (cute::size(lhs) != cute::size(rhs)) {
return false;
}
for (int64_t idx = 0; idx < cute::size(lhs); ++idx) {
if (lhs(idx) != rhs(idx)) {
return false;
}
}
return true;
}
/// Returns true if two tensor views are NOT equal.
template <
typename TensorL,
typename TensorR
>
bool TensorNotEquals(
TensorL lhs,
TensorR rhs) {
return !TensorEquals(lhs, rhs);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace host
} // namespace reference
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
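A hedged usage sketch (not part of this commit): reporting a pass/fail disposition for a computed tensor against a host reference; report_disposition is an illustrative helper name.
#include <cstdio>
template <class TensorL, class TensorR>
void report_disposition(TensorL const& computed, TensorR const& expected)
{
  bool passed = cutlass::reference::host::TensorEquals(computed, expected);
  std::printf("Disposition: %s\n", passed ? "Passed" : "Failed");
}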

View File

@@ -0,0 +1,432 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Provides several functions for filling tensors with data.
*/
#pragma once
// Standard Library includes
#include <utility>
#include <cstdlib>
#include <cmath>
// Cute includes
#include "cute/tensor.hpp"
// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/quaternion.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace reference {
namespace host {
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Uniform and procedural tensor fills
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Fills a tensor with a scalar element
template <typename Tensor>
void TensorFill(Tensor dst, typename Tensor::value_type element) {
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
dst(idx) = element;
}
}
/// Fills a tensor with the contents of its layout
template <typename Tensor>
void TensorFillSequential(Tensor dst) {
auto layout = dst.layout();
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
dst(idx) = layout(idx);
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Random uniform values
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename Element>
struct RandomUniformFunc {
using Real = typename RealType<Element>::Type;
uint64_t seed;
double range;
double min;
int int_scale;
//
// Methods
//
RandomUniformFunc(
uint64_t seed_ = 0,
double max = 1,
double min_ = 0,
int int_scale_ = -1
):
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
std::srand((unsigned)seed);
}
/// Compute random value and update RNG state
Element operator()() const {
double rnd = double(std::rand()) / double(RAND_MAX);
rnd = min + range * rnd;
// Random values are cast to integer after scaling by a power of two to facilitate error
// testing
Element result;
if (int_scale >= 0) {
rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
result = static_cast<Element>(Real(rnd));
}
else {
result = static_cast<Element>(Real(rnd));
}
return result;
}
};
/// Partial specialization for initializing a complex value.
template <typename Element>
struct RandomUniformFunc<complex<Element> > {
using Real = typename RealType<Element>::Type;
uint64_t seed;
double range;
double min;
int int_scale;
//
// Methods
//
RandomUniformFunc(
uint64_t seed_ = 0,
double max = 1,
double min_ = 0,
int int_scale_ = -1
):
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
std::srand((unsigned)seed);
}
/// Compute random value and update RNG state
complex<Element> operator()() const {
Element reals[2];
for (int i = 0; i < 2; ++i) {
double rnd = double(std::rand()) / double(RAND_MAX);
rnd = min + range * rnd;
// Random values are cast to integer after scaling by a power of two to facilitate error
// testing
if (int_scale >= 0) {
rnd = double(int(rnd * double(1 << int_scale)));
reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
}
else {
reals[i] = from_real<Element>(Real(rnd));
}
}
return complex<Element>(reals[0], reals[1]);
}
};
/// Partial specialization for initializing a Quaternion value.
template <typename Element>
struct RandomUniformFunc<Quaternion<Element> > {
using Real = typename RealType<Element>::Type;
uint64_t seed;
double range;
double min;
int int_scale;
//
// Methods
//
RandomUniformFunc(
uint64_t seed_ = 0,
double max = 1,
double min_ = 0,
int int_scale_ = -1
):
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
std::srand((unsigned)seed);
}
/// Compute random value and update RNG state
Quaternion<Element> operator()() const {
Element reals[4];
for (int i = 0; i < 4; ++i) {
double rnd = double(std::rand()) / double(RAND_MAX);
rnd = min + range * rnd;
// Random values are cast to integer after scaling by a power of two to facilitate error
// testing
if (int_scale >= 0) {
rnd = double(int(rnd * double(1 << int_scale)));
reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
}
else {
reals[i] = from_real<Element>(Real(rnd));
}
}
return make_Quaternion(reals[0], reals[1], reals[2], reals[3]);
}
};
} // namespace detail
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Fills a tensor with random values with a uniform random distribution.
template <typename Tensor> ///< Tensor object
void TensorFillRandomUniform(
Tensor dst, ///< destination tensor
uint64_t seed, ///< seed for RNG
double max = 1, ///< upper bound of distribution
double min = 0, ///< lower bound for distribution
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
detail::RandomUniformFunc<typename Tensor::value_type> random_func(seed, max, min, bits);
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
dst(idx) = random_func();
}
}
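// Example usage (illustrative sketch, not part of the original file; the buffer extents and the
// seed are arbitrary example values):
//
//   #include <vector>
//
//   std::vector<float> buf(128 * 128);
//   auto t = cute::make_tensor(buf.data(), cute::make_layout(cute::make_shape(128, 128)));
//   cutlass::reference::host::TensorFillRandomUniform(t, /*seed=*/2023, /*max=*/1.0, /*min=*/-1.0);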
/// Fills a block with random values drawn from a uniform distribution.
template <
typename Element ///< Element type
>
void BlockFillRandomUniform(
Element *ptr,
size_t capacity,
uint64_t seed, ///< seed for RNG
double max = 1, ///< upper bound of distribution
double min = 0, ///< lower bound for distribution
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
for (size_t i = 0; i < capacity; ++i) {
ptr[i] = random_func();
}
}
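// Example usage (illustrative sketch, not part of the original file): with bits = 0 every value
// is quantized to an integer, so the block below holds values from {-2, -1, 0, 1, 2}.
//
//   std::vector<cutlass::half_t> block(1024);
//   cutlass::reference::host::BlockFillRandomUniform(
//       block.data(), block.size(), /*seed=*/42, /*max=*/2.0, /*min=*/-2.0, /*bits=*/0);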
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Random Gaussian
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename Element>
struct RandomGaussianFunc {
uint64_t seed;
double mean;
double stddev;
int int_scale;
double pi;
//
// Methods
//
RandomGaussianFunc(
uint64_t seed_ = 0,
double mean_ = 0,
double stddev_ = 1,
int int_scale_ = -1
):
seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) {
std::srand((unsigned)seed);
}
/// Compute random value and update RNG state
Element operator()() const {
// Box-Muller transform to generate random numbers with Normal distribution
    // Shift into (0, 1] so that log(u1) below stays finite even when std::rand() returns 0
    double u1 = (double(std::rand()) + 1.0) / (double(RAND_MAX) + 1.0);
double u2 = double(std::rand()) / double(RAND_MAX);
// Compute Gaussian random value
double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
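    // (Added note, not in the original file.) Box-Muller: for independent u1, u2 ~ Uniform(0, 1],
    //   z = sqrt(-2 * ln(u1)) * cos(2 * pi * u2)
    // is a standard normal sample; scaling by stddev and shifting by mean below yields
    // N(mean, stddev^2).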
rnd = mean + stddev * rnd;
// Scale and convert final result
Element result;
if (int_scale >= 0) {
rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
result = static_cast<Element>(rnd);
}
else {
result = static_cast<Element>(rnd);
}
return result;
}
};
} // namespace detail
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Fills a tensor with random values drawn from a Gaussian distribution.
template <
typename Tensor
>
void TensorFillRandomGaussian(
Tensor dst, ///< destination tensor
uint64_t seed, ///< seed for RNG
double mean = 0, ///< Gaussian distribution's mean
double stddev = 1, ///< Gaussian distribution's standard deviation
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
detail::RandomGaussianFunc<typename Tensor::value_type> random_func(seed, mean, stddev, bits);
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
dst(idx) = random_func();
}
}
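// Example usage (illustrative sketch, not part of the original file): fill a host tensor with
// samples drawn from N(mean = 0, stddev = 0.5).
//
//   std::vector<float> buf(64 * 64);
//   auto t = cute::make_tensor(buf.data(), cute::make_layout(cute::make_shape(64, 64)));
//   cutlass::reference::host::TensorFillRandomGaussian(t, /*seed=*/7, /*mean=*/0.0, /*stddev=*/0.5);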
/// Fills a block with random values drawn from a Gaussian distribution.
template <
typename Element ///< Element type
>
void BlockFillRandomGaussian(
Element *ptr, ///< destination buffer
size_t capacity, ///< number of elements
uint64_t seed, ///< seed for RNG
double mean = 0, ///< Gaussian distribution's mean
double stddev = 1, ///< Gaussian distribution's standard deviation
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits);
for (size_t i = 0; i < capacity; ++i) {
ptr[i] = random_func();
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Fills a block of data with sequential elements
template <
typename Element
>
void BlockFillSequential(
Element *ptr,
int64_t capacity,
Element v = Element(1),
Element s = Element(0)) {
  int64_t i = 0;
  while (i < capacity) {
    ptr[i] = Element(s);
    s = Element(s + v);
    ++i;
  }
}
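// Example (illustrative sketch, not part of the original file): with the defaults v = 1, s = 0
// the block receives 0, 1, 2, ..., capacity - 1.
//
//   std::vector<int> idx(8);
//   cutlass::reference::host::BlockFillSequential(idx.data(), int64_t(idx.size()));
//   // idx now holds {0, 1, 2, 3, 4, 5, 6, 7}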
/// Fills a block of data with sequential elements reduced modulo a given value
template <
typename Element
>
void BlockFillSequentialModN(
Element *ptr,
int64_t capacity,
int64_t mod,
int64_t v = int64_t(1),
int64_t s = int64_t(0)) {
  int64_t i = 0;
  while (i < capacity) {
    ptr[i] = static_cast<Element>(int32_t(s % mod));
    s = int64_t(s + v);
    ++i;
  }
}
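// Example (illustrative sketch, not part of the original file): with mod = 4, v = 1, s = 0 the
// block receives the repeating pattern 0, 1, 2, 3, 0, 1, 2, 3, ...
//
//   std::vector<int8_t> idx(8);
//   cutlass::reference::host::BlockFillSequentialModN(idx.data(), int64_t(idx.size()), /*mod=*/4);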
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace host
} // namespace reference
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,203 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Provides several functions for computing reductions and norms of tensors on the host.
*/
#pragma once
// Standard Library includes
#include <utility>
#include <cstdint>
#include <cstdlib>
#include <cmath>
#include <stdexcept>
// Cute includes
#include "cute/tensor.hpp"
// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/quaternion.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace reference {
namespace host {
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Tensor reductions
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Transform-reduce operation over the elements of a tensor
template <
typename Tensor,
typename ComputeType,
typename ReduceOp,
typename TransformOp
>
ComputeType TensorTransformReduce(
Tensor view,
ComputeType identity,
ReduceOp reduce,
TransformOp transform
) {
for (int64_t idx = 0; idx < cute::size(view); ++idx) {
identity = reduce(identity, transform(view(idx)));
}
return identity;
}
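// Example (illustrative sketch, not part of the original file): compute the maximum absolute
// value of a host tensor `view` of float elements by supplying lambdas as the reduce and
// transform operators.
//
//   double max_abs = cutlass::reference::host::TensorTransformReduce(
//       view, 0.0,
//       [](double a, double b) { return a > b ? a : b; },   // reduce: max
//       [](float x) { return std::abs(double(x)); });       // transform: |x|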
/// Transform-reduce operation over the elements of two tensors, combined element-wise
template <
typename TensorA,
typename TensorB,
typename ComputeType,
typename ReduceOp,
typename TransformOp
>
ComputeType TensorTransformReduce(
TensorA view_A,
TensorB view_B,
ComputeType identity,
ReduceOp reduce,
TransformOp transform) {
if (cute::size(view_A) != cute::size(view_B)) {
throw std::runtime_error("Tensor sizes must match.");
}
for (int64_t idx = 0; idx < cute::size(view_A); ++idx) {
identity = reduce(identity, transform(view_A(idx), view_B(idx)));
}
return identity;
}
/// Helper to compute the sum of the elements of a tensor
template <
typename Tensor,
typename ComputeType = typename Tensor::value_type
>
ComputeType TensorSum(
Tensor view,
ComputeType identity = ComputeType()
) {
plus<ComputeType> reduce;
NumericConverter<ComputeType, typename Tensor::value_type> transform;
return TensorTransformReduce(
view, identity, reduce, transform);
}
/// Helper to compute the sum of the squares of the elements of a tensor
template <
typename Tensor,
typename ComputeType = typename Tensor::value_type
>
ComputeType TensorSumSq(
Tensor view,
ComputeType identity = ComputeType()
) {
plus<ComputeType> reduce;
magnitude_squared<typename Tensor::value_type, ComputeType> transform;
return TensorTransformReduce(
view, identity, reduce, transform);
}
/// Helper to compute the norm of the elements of a tensor.
template <
typename Tensor,
typename ComputeType = double
>
ComputeType TensorNorm(
Tensor view,
ComputeType identity = ComputeType()
) {
return std::sqrt(TensorSumSq(view, identity));
}
/// Helper to compute the sum of the squares of the differences of two tensors
template <
typename TensorA,
typename TensorB,
typename ComputeType = double
>
ComputeType TensorSumSqDiff(
TensorA view_A,
TensorB view_B,
ComputeType identity = ComputeType()
) {
plus<ComputeType> reduce;
magnitude_squared_difference<typename TensorA::value_type, ComputeType> transform;
return TensorTransformReduce(
view_A, view_B, identity, reduce, transform);
}
/// Helper to compute the norm of the difference between two tensors
template <
typename TensorA,
typename TensorB,
typename ComputeType = double
>
ComputeType TensorNormDiff(
TensorA view_A,
TensorB view_B,
ComputeType identity = ComputeType()
) {
return std::sqrt(TensorSumSqDiff(view_A, view_B, identity));
}
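// Example (illustrative sketch, not part of the original file): the usual verification pattern
// divides the norm of the difference by the norm of the reference to obtain a relative error.
// `ref_view` and `test_view` are assumed to be host tensors of equal size; the tolerance is an
// arbitrary example value.
//
//   double rel_err = cutlass::reference::host::TensorNormDiff(ref_view, test_view)
//                  / cutlass::reference::host::TensorNorm(ref_view);
//   bool passed = (rel_err < 1e-5);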
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace host
} // namespace reference
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////