CUTLASS 2.9 (#468)
@@ -1,19 +1,31 @@
 /******************************************************************************
- * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
  * Redistribution and use in source and binary forms, with or without
- * modification, are not permitted.
+ * modification, are permitted provided that the following conditions are met:
  *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  ******************************************************************************/

@@ -1,24 +1,30 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright notice, this list of
- * conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright notice, this list of
- * conditions and the following disclaimer in the documentation and/or other materials
- * provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- * to endorse or promote products derived from this software without specific prior written
- * permission.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
  *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/

141
tools/util/include/cutlass/util/device_nchw_to_nhwc.h
Normal file
@@ -0,0 +1,141 @@
/******************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout.
 */

#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/tensor_ref.h"

namespace cutlass {

/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout.
 * \tparam T: data type
 */
template <typename T>
void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size,
                  cutlass::Tensor4DCoord output_tensor_size,
                  TensorRef<T, layout::TensorNCHW> ref_input,
                  TensorRef<T, layout::TensorNHWC> ref_output,
                  cudaStream_t stream);

template <typename T>
__global__ void nchw_to_nhwc_kernel(T *output,
                                    const T *input,
                                    const int n,
                                    const int h,
                                    const int w,
                                    const int c) {
  const int hw = h*w;
  const int chw = c*hw;
  __shared__ T shbuf[32 * (32 + 1)];
  const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x;
  const int32_t wid = tid / 32;
  const int32_t lid = tid % 32;
  const int32_t ni = blockIdx.z;
  const int32_t ci0 = blockIdx.y * 32;
  const int32_t hwi0 = blockIdx.x * 32;

  const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0;
  const T *A = input + input_idx;
  if (hwi0 + lid < hw) {
    const int lid_x_33 = lid * 33;
    if ((ci0 + 32) <= c) {
      int ci = wid; // between 0 and 7
      CUTLASS_PRAGMA_UNROLL
      for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) {
        shbuf[lid_x_33 + ci] = A[lid];
        A = &A[8 * hw];
        ci += 8;
      }
    } else {
      for (int ci = wid; ci < 32; ci += 8) {
        if ((ci + ci0) < c) {
          shbuf[lid_x_33 + ci] = A[lid];
        }
        A = &A[8 * hw];
      }
    }
  }
  __syncthreads();

  const int32_t ciOut = ci0 + lid;
  output = &output[ni * chw + ciOut];
  if (ciOut < c) {
    if (hwi0 + 32 < hw) {
      int hwI = wid;
      CUTLASS_PRAGMA_UNROLL
      for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) {
        output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid];
        hwI += 8;
      }
    } else {
      for (int hwI = wid; hwI < 32; hwI += 8) {
        if (hwi0 + hwI < hw) {
          output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid];
        }
      }
    }
  }
}

template <typename T>
void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size,
                  cutlass::Tensor4DCoord output_tensor_size,
                  TensorRef<T, layout::TensorNCHW> ref_input,
                  TensorRef<T, layout::TensorNHWC> ref_output,
                  cudaStream_t stream) {

  assert(
    input_tensor_size.n() == output_tensor_size.n() &&
    input_tensor_size.c() == output_tensor_size.h() &&
    input_tensor_size.h() == output_tensor_size.w() &&
    input_tensor_size.w() == output_tensor_size.c());

  int n = output_tensor_size.n();
  int h = output_tensor_size.h();
  int w = output_tensor_size.w();
  int c = output_tensor_size.c();

  dim3 grid((h*w + 31)/32, (c + 31)/32, n);
  dim3 block(32, 8);
  nchw_to_nhwc_kernel<<<grid, block, 0, stream>>>(ref_output.data(), ref_input.data(),
                                                   n, h, w, c);
}

} //namespace cutlass
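Review note: the index arithmetic in `nchw_to_nhwc_kernel` can be checked against a plain host-side loop. The sketch below is not part of this commit or of CUTLASS; `nchw_to_nhwc_reference` and `distinct_banks` are hypothetical helper names, and the tensors are assumed to be densely packed. The second helper illustrates why the shared buffer is sized `32 * (32 + 1)`: assuming 32 shared-memory banks and 4-byte elements, a row stride of 33 lets the 32 lanes of a warp hit 32 different banks when they write `shbuf[lid * 33 + ci]`, whereas a stride of 32 would map every lane to the same bank.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference (not part of CUTLASS): copies a densely
// packed NCHW tensor into a densely packed NHWC tensor, one element at a time.
template <typename T>
std::vector<T> nchw_to_nhwc_reference(const std::vector<T> &input,
                                      int n, int c, int h, int w) {
  const std::size_t hw = static_cast<std::size_t>(h) * w;
  assert(input.size() == static_cast<std::size_t>(n) * c * hw);
  std::vector<T> output(input.size());
  for (int ni = 0; ni < n; ++ni) {
    for (int ci = 0; ci < c; ++ci) {
      for (std::size_t hwi = 0; hwi < hw; ++hwi) {
        // NCHW keeps each channel plane contiguous; NHWC keeps channels innermost.
        const std::size_t src = (static_cast<std::size_t>(ni) * c + ci) * hw + hwi;
        const std::size_t dst = (static_cast<std::size_t>(ni) * hw + hwi) * c + ci;
        output[dst] = input[src];
      }
    }
  }
  return output;
}

// Illustration only, assuming 32 banks and 4-byte elements: counts how many
// distinct banks the 32 lanes of a warp touch when lane i accesses element
// i * row_stride of a shared-memory array.
inline int distinct_banks(int row_stride) {
  bool seen[32] = {false};
  int count = 0;
  for (int lane = 0; lane < 32; ++lane) {
    const int bank = (lane * row_stride) % 32;
    if (!seen[bank]) {
      seen[bank] = true;
      ++count;
    }
  }
  return count;
}
```

A device result copied back to the host could be compared element by element against `nchw_to_nhwc_reference` for small shapes, including shapes where `c` or `h*w` is not a multiple of 32 so that the kernel's bounds-checked branches are exercised.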
276
tools/util/include/cutlass/util/device_nhwc_padding.h
Normal file
@@ -0,0 +1,276 @@
/******************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief cuda kernels for padding in device memory with NHWC layout.
 */

#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/tensor_ref.h"

namespace cutlass {

/** \brief interface for padding in a device memory tensor with NHWC layout
 * \tparam T: data type
 */
template <typename T>
void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size,
                  cutlass::Tensor4DCoord output_tensor_size,
                  TensorRef<T, layout::TensorNHWC> ref_input,
                  TensorRef<T, layout::TensorNHWC> ref_output,
                  cudaStream_t stream);


template <typename T>
__global__ void nhwc_padding_kernel(const int32_t n,
                                    const int32_t h,
                                    const int32_t w,
                                    const int32_t c_in,
                                    const int32_t c_out,
                                    const T zero,
                                    const T *input,
                                    T *output){

  const int32_t idx_jump = blockDim.x * gridDim.x;
  const int32_t total_elements = n * h * w * c_out;

  int32_t c_idx, w_idx, h_idx, n_idx, resudial;

  T value;
  for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) {

    c_idx = idx%c_out;
    if (c_idx >= c_in){
      value = zero;
    }
    else{
      resudial = idx/c_out;
      w_idx = resudial%w;
      resudial = resudial/w;
      h_idx = resudial%h;
      n_idx = resudial/h;
      resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx;
      value = input[resudial];
    }
    output[idx] = value;
  }
}


// fast kernel for c_in = 3 & c_out = 4
template <typename Tio, typename Telement, int element_in_Tio>
__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n,
                                                 const int32_t h,
                                                 const int32_t w,
                                                 const Tio *input,
                                                 Tio *output,
                                                 const int32_t max_output_element,
                                                 const int32_t max_input_element,
                                                 const Tio zero_io,
                                                 const Telement zero_element){
  __shared__ Tio shm[192];
  const int tidx = blockIdx.x * 192 + threadIdx.x;
  const int threadidx = threadIdx.x;

  shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx];
  __syncthreads();

  const int ouput_offset = blockIdx.x * 256;
  const int lower_bound = max_output_element < ouput_offset + 256 ? max_output_element : ouput_offset + 256;
  for (int i = ouput_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192)
  {
    const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4;
    Telement array[element_in_Tio];
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0 ; k < element_in_Tio ; k++)
      array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k];
    output[i] = *((const Tio *)array);
  }
}

// fast kernel for c_in = 3 & c_out = 8
template <typename Tio, typename Telement, int element_in_Tio>
__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n,
                                                 const int32_t h,
                                                 const int32_t w,
                                                 const Tio *input,
                                                 Tio *output,
                                                 const int32_t max_output_element,
                                                 const int32_t max_input_element,
                                                 const Tio zero_io,
                                                 const Telement zero_element){
  __shared__ Tio shm[192];
  const int tidx = blockIdx.x * 192 + threadIdx.x;
  const int threadidx = threadIdx.x;

  shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx];
  __syncthreads();

  const int ouput_offset = blockIdx.x * 512;
  const int lower_bound = max_output_element < ouput_offset + 512 ? max_output_element : ouput_offset + 512;
  for (int i = ouput_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192)
  {
    const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3;
    Telement array[element_in_Tio];
    //float
    if (element_in_Tio == 4){
      CUTLASS_PRAGMA_UNROLL
      for (int k = 0 ; k < element_in_Tio ; k++)
        array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]);
    }
    //half
    else{
      CUTLASS_PRAGMA_UNROLL
      for (int k = 0 ; k < element_in_Tio ; k++)
        array[k] = (k >= 3) ? zero_element : shm_element[k];
    }
    output[i] = *((const Tio *)array);
  }
}

template <typename T>
void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size,
                  cutlass::Tensor4DCoord output_tensor_size,
                  TensorRef<T, layout::TensorNHWC> ref_input,
                  TensorRef<T, layout::TensorNHWC> ref_output,
                  cudaStream_t stream){
  assert(
    input_tensor_size.n() == output_tensor_size.n() &&
    input_tensor_size.h() == output_tensor_size.h() &&
    input_tensor_size.w() == output_tensor_size.w() &&
    input_tensor_size.c() <= output_tensor_size.c());

  int n = input_tensor_size.n();
  int h = input_tensor_size.h();
  int w = input_tensor_size.w();
  int c_in = input_tensor_size.c();
  int c_out = output_tensor_size.c();

  //case 1 : channel == 3 padding to 4 or 8
  if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){
    dim3 block(192);
    const int nhw = n*h*w;
    const int nhwc = nhw*c_in;
    //for half_t
    if (cutlass::sizeof_bits<T>::value == 16){
      const int element_in_Tio = 8;
      const int max_input_element = nhwc/element_in_Tio;
      const int max_output_element = nhw*c_out/element_in_Tio;
      const int4 zero_io = {0, 0, 0, 0};
      const half_t zero_element = static_cast<half_t>(0.0f);
      dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio));
      if (c_out == 4){
        nhwc_padding_channel_3To4_kernel<int4, half_t, element_in_Tio><<<grid, block, 0, stream>>>
          (n, h, w,
           (const int4 *)ref_input.data(),
           (int4 *)ref_output.data(),
           max_output_element,
           max_input_element,
           zero_io,
           zero_element);
      }
      else if (c_out == 8){
        nhwc_padding_channel_3To8_kernel<int4, half_t, element_in_Tio><<<grid, block, 0, stream>>>
          (n, h, w,
           (const int4 *)ref_input.data(),
           (int4 *)ref_output.data(),
           max_output_element,
           max_input_element,
           zero_io,
           zero_element);
      }
    }
    //for float
    else{
      const int element_in_Tio = 4;
      const int max_input_element = nhwc/element_in_Tio;
      const int max_output_element = nhw*c_out/element_in_Tio;
      const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f};
      const float zero_element = 0.0f;
      dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio));
      if (c_out == 4){
        nhwc_padding_channel_3To4_kernel<float4, float, element_in_Tio><<<grid, block, 0, stream>>>
          (n, h, w,
           (const float4 *)ref_input.data(),
           (float4 *)ref_output.data(),
           max_output_element,
           max_input_element,
           zero_io,
           zero_element);
      }
      else if (c_out == 8){
        nhwc_padding_channel_3To8_kernel<float4, float, element_in_Tio><<<grid, block, 0, stream>>>
          (n, h, w,
           (const float4 *)ref_input.data(),
           (float4 *)ref_output.data(),
           max_output_element,
           max_input_element,
           zero_io,
           zero_element);
      }
    }
  }
  //case 2 : even channel
  else if ((c_out % 2) == 0 && (c_in % 2) == 0){
    int32_t total_elements = n * h * w * c_out / 2;
    int block_size = 256;
    dim3 grid((total_elements + 255)/256);
    dim3 block(block_size);
    //for half_t
    if (cutlass::sizeof_bits<T>::value == 16){
      const __half2 zero = {0.0f, 0.0f};
      nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data());
    }
    //for float
    else{
      const float2 zero = {0.0f, 0.0f};
      nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data());
    }
  }
  //case 3 : odd channel
  else{
    int32_t total_elements = n * h * w * c_out;
    int block_size = 256;
    dim3 grid((total_elements + 255)/256);
    dim3 block(block_size);
    const T zero = static_cast<T>(0.0f);
    nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data());
  }
}


} //namespace cutlass
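Review note: the three launch paths in `nhwc_padding` should all produce the same result, namely each pixel's `c_in` channels copied and the remaining `c_out - c_in` channels set to zero. The sketch below is a host-side illustration under that reading, not code from this commit; `nhwc_padding_reference`, `PaddingPath`, and `select_padding_path` are hypothetical names, and tensors are assumed to be densely packed. The selector mirrors the `if`/`else if`/`else` conditions in the host function so the dispatch can be reasoned about without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference (not part of CUTLASS): pads the channel
// dimension of a densely packed NHWC tensor from c_in to c_out with zeros.
template <typename T>
std::vector<T> nhwc_padding_reference(const std::vector<T> &input,
                                      int n, int h, int w, int c_in, int c_out) {
  assert(c_in <= c_out);
  const std::size_t pixels = static_cast<std::size_t>(n) * h * w;
  assert(input.size() == pixels * c_in);
  std::vector<T> output(pixels * c_out, T(0));  // padded channels stay zero
  for (std::size_t p = 0; p < pixels; ++p) {
    for (int ci = 0; ci < c_in; ++ci) {
      // Same mapping as nhwc_padding_kernel: pixel index * c_in + channel.
      output[p * c_out + ci] = input[p * c_in + ci];
    }
  }
  return output;
}

// Hypothetical labels for the three branches of the host function.
enum class PaddingPath { Channel3Vectorized, EvenChannelPairs, Generic };

// Mirrors the branch conditions of nhwc_padding, in the same order.
inline PaddingPath select_padding_path(int n, int h, int w, int c_in, int c_out) {
  if ((c_out == 4 || c_out == 8) && c_in == 3 && (n * h * w % 8 == 0)) {
    return PaddingPath::Channel3Vectorized;  // case 1: 3 -> 4 or 3 -> 8
  }
  if ((c_out % 2) == 0 && (c_in % 2) == 0) {
    return PaddingPath::EvenChannelPairs;    // case 2: processed as 2-element vectors
  }
  return PaddingPath::Generic;               // case 3: one element per thread step
}
```

Note that a 3-channel input whose `n*h*w` is not a multiple of 8 falls through to the generic path, because `c_in` is odd and the even-channel condition also fails.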
144
tools/util/include/cutlass/util/device_nhwc_to_nchw.h
Normal file
@@ -0,0 +1,144 @@
/******************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout.
 */

#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/tensor_ref.h"

namespace cutlass {

/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout.
 * \tparam T: data type
 */
template <typename T>
void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size,
                  cutlass::Tensor4DCoord output_tensor_size,
                  TensorRef<T, layout::TensorNHWC> ref_input,
                  TensorRef<T, layout::TensorNCHW> ref_output,
                  cudaStream_t stream);


template <typename T>
|
||||
__global__ void nhwc_to_nchw_kernel(T *output,
|
||||
const T *input,
|
||||
const int n,
|
||||
const int h,
|
||||
const int w,
|
||||
const int c) {
|
||||
|
||||
const int hw = h*w;
|
||||
const int hwc = hw*c;
|
||||
__shared__ T shbuf[32 * (32 + 1)];
|
||||
const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x;
|
||||
const int32_t wid = tid / 32;
|
||||
const int32_t lid = tid % 32;
|
||||
const int32_t ni = blockIdx.z;
|
||||
const int32_t hwi0 = blockIdx.y * 32;
|
||||
const int32_t ci0 = blockIdx.x * 32;
|
||||
|
||||
const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0;
|
||||
const T *A = input + input_idx;
|
||||
if (ci0 + lid < c) {
|
||||
const int lid_x_33 = lid * 33;
|
||||
if ((hwi0 + 32) <= hw) {
|
||||
int hwi = wid; // between 0 and 7
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) {
|
||||
shbuf[lid_x_33 + hwi] = A[lid];
|
||||
A = &A[8 * c];
|
||||
hwi += 8;
|
||||
}
|
||||
} else {
|
||||
for (int hwi = wid; hwi < 32; hwi += 8) {
|
||||
if ((hwi + hwi0) < hw) {
|
||||
shbuf[lid_x_33 + hwi] = A[lid];
|
||||
}
|
||||
A = &A[8 * c];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const int32_t hwiOut = hwi0 + lid;
|
||||
output = &output[ni * hwc + hwiOut];
|
||||
if (hwiOut < hw) {
|
||||
if (ci0 + 32 < c) {
|
||||
int cI = wid;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) {
|
||||
output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid];
|
||||
cI += 8;
|
||||
}
|
||||
} else {
|
||||
for (int cI = wid; cI < 32; cI += 8) {
|
||||
if (ci0 + cI < c) {
|
||||
output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size,
|
||||
cutlass::Tensor4DCoord output_tensor_size,
|
||||
TensorRef<T, layout::TensorNHWC> ref_input,
|
||||
TensorRef<T, layout::TensorNCHW> ref_output,
|
||||
cudaStream_t stream) {
|
||||
|
||||
assert(
|
||||
input_tensor_size.n() == output_tensor_size.n() &&
|
||||
input_tensor_size.h() == output_tensor_size.c() &&
|
||||
input_tensor_size.w() == output_tensor_size.h() &&
|
||||
input_tensor_size.c() == output_tensor_size.w());
|
||||
|
||||
int n = input_tensor_size.n();
|
||||
int h = input_tensor_size.h();
|
||||
int w = input_tensor_size.w();
|
||||
int c = input_tensor_size.c();
|
||||
|
||||
dim3 grid((c + 31)/32, (h*w + 31)/32, n);
|
||||
dim3 block(32, 8);
|
||||
nhwc_to_nchw_kernel<<<grid, block, 0, stream>>>(ref_output.data(), ref_input.data(),
|
||||
n, h, w, c);
|
||||
|
||||
}
|
||||
|
||||
} //namespace cutlass
|
||||
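As a reading aid for the kernel above, the following host-side C++ sketch (not part of this commit) spells out the index mapping it performs, `out[n][c][hw] = in[n][hw][c]`, and emulates the 32 x 32 tiling with a staging tile whose row stride is padded to 33, as `shbuf` is. The padding is presumably there to avoid shared-memory bank conflicts. The names `nhwc_to_nchw_reference` and `nhwc_to_nchw_tiled` are hypothetical, and the sequential loops only approximate what the 32 x 8 thread block does in parallel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference (hypothetical helper): out[n][c][hw] = in[n][hw][c].
template <typename T>
std::vector<T> nhwc_to_nchw_reference(const std::vector<T>& in, int n, int h, int w, int c) {
  const std::size_t hw = static_cast<std::size_t>(h) * w;
  const std::size_t cs = static_cast<std::size_t>(c);
  std::vector<T> out(in.size());
  for (std::size_t ni = 0; ni < static_cast<std::size_t>(n); ++ni)
    for (std::size_t hwi = 0; hwi < hw; ++hwi)
      for (std::size_t ci = 0; ci < cs; ++ci)
        out[(ni * cs + ci) * hw + hwi] = in[(ni * hw + hwi) * cs + ci];
  return out;
}

// Tiled emulation (hypothetical helper): one loop iteration per (n, hw-tile, c-tile),
// mirroring the grid ((c + 31) / 32, (h * w + 31) / 32, n) used by the launcher.
template <typename T>
std::vector<T> nhwc_to_nchw_tiled(const std::vector<T>& in, int n, int h, int w, int c) {
  const std::size_t kTile = 32;
  const std::size_t kStride = kTile + 1;  // padded row stride, like shbuf[32 * (32 + 1)]
  const std::size_t hw = static_cast<std::size_t>(h) * w;
  const std::size_t cs = static_cast<std::size_t>(c);
  std::vector<T> out(in.size());
  std::vector<T> tile(kTile * kStride);
  for (std::size_t ni = 0; ni < static_cast<std::size_t>(n); ++ni) {
    for (std::size_t hwi0 = 0; hwi0 < hw; hwi0 += kTile) {
      for (std::size_t ci0 = 0; ci0 < cs; ci0 += kTile) {
        // Load phase: read rows that are contiguous in c, store them transposed in the tile.
        for (std::size_t hl = 0; hl < kTile && hwi0 + hl < hw; ++hl)
          for (std::size_t cl = 0; cl < kTile && ci0 + cl < cs; ++cl)
            tile[cl * kStride + hl] = in[(ni * hw + hwi0 + hl) * cs + ci0 + cl];
        // Store phase: write rows that are contiguous in hw.
        for (std::size_t cl = 0; cl < kTile && ci0 + cl < cs; ++cl)
          for (std::size_t hl = 0; hl < kTile && hwi0 + hl < hw; ++hl)
            out[(ni * cs + ci0 + cl) * hw + hwi0 + hl] = tile[cl * kStride + hl];
      }
    }
  }
  return out;
}
```

Comparing the two functions on a shape whose `h * w` and `c` are not multiples of 32 exercises the same tail conditions that the kernel guards with `hwi + hwi0 < hw` and `ci0 + cI < c`.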
@ -1,24 +1,30 @
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 * * Redistributions of source code must retain the above copyright notice, this list of
 * conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright notice, this list of
 * conditions and the following disclaimer in the documentation and/or other materials
 * provided with the distribution.
 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 * to endorse or promote products derived from this software without specific prior written
 * permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

@ -1,19 +1,31 @
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are not permitted.
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

@ -1,52 +1,38 @
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 * * Redistributions of source code must retain the above copyright notice, this list of
 * conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright notice, this list of
 * conditions and the following disclaimer in the documentation and/or other materials
 * provided with the distribution.
 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 * to endorse or promote products derived from this software without specific prior written
 * permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include <utility>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

/**
 * \file
 * \brief C++11 version of index_sequence.
 */
// integer_sequence moved to cutlass/numeric_types.h

namespace cutlass {

template <size_t... Seq>
struct index_sequence;

template <size_t N, size_t... Next>
struct index_sequence_helper : index_sequence_helper<N - 1, N - 1, Next...> {};

template <size_t... Next>
struct index_sequence_helper<0, 0, Next...> {
  using type = index_sequence<0, Next...>;
};

template <size_t N>
using make_index_sequence = typename index_sequence_helper<N>::type;

} // namespace cutlass

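To make the recursion in `index_sequence_helper` easier to follow, here is a small standalone C++11 sketch (not part of this commit) that copies the helper into a hypothetical `demo` namespace and checks what it expands to. The empty definition of `index_sequence` and the functions `sum_impl` and `sum_first` are assumptions added for illustration only; the hunk above merely declares `index_sequence`.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

namespace demo {  // hypothetical namespace mirroring the helper shown in the diff

// Assumed empty definition so the type can be instantiated and passed by value.
template <std::size_t... Seq>
struct index_sequence {};

// Each step prepends N - 1 to the pack: <3> -> <2, 2> -> <1, 1, 2> -> <0, 0, 1, 2>.
template <std::size_t N, std::size_t... Next>
struct index_sequence_helper : index_sequence_helper<N - 1, N - 1, Next...> {};

// Base case: the counter reached 0, so the accumulated pack is 0, Next...
template <std::size_t... Next>
struct index_sequence_helper<0, 0, Next...> {
  using type = index_sequence<0, Next...>;
};

template <std::size_t N>
using make_index_sequence = typename index_sequence_helper<N>::type;

// Illustrative use: expand the indices to read the first N elements of an array.
template <typename T, std::size_t... I>
T sum_impl(const T* data, index_sequence<I...>) {
  T parts[] = {data[I]...};
  T total = T();
  for (const T& p : parts) total += p;
  return total;
}

template <std::size_t N, typename T>
T sum_first(const T* data) {
  return sum_impl(data, make_index_sequence<N>());
}

}  // namespace demo

static_assert(std::is_same<demo::make_index_sequence<3>,
                           demo::index_sequence<0, 1, 2>>::value,
              "make_index_sequence<3> expands to 0, 1, 2");
static_assert(std::is_same<demo::make_index_sequence<1>,
                           demo::index_sequence<0>>::value,
              "make_index_sequence<1> expands to 0");
```

With this formulation `make_index_sequence<0>` has no matching base case, because the specialization needs at least two leading zeros, so the sketch only uses `N >= 1`.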
@@ -1,24 +1,30 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright notice, this list of
- * conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright notice, this list of
- * conditions and the following disclaimer in the documentation and/or other materials
- * provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- * to endorse or promote products derived from this software without specific prior written
- * permission.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
  *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/

@@ -1,24 +1,30 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright notice, this list of
- * conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright notice, this list of
- * conditions and the following disclaimer in the documentation and/or other materials
- * provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- * to endorse or promote products derived from this software without specific prior written
- * permission.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
  *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/
@@ -227,36 +233,80 @@ void GemmComplex(
     batch_count % std::numeric_limits<uint16_t>::max()
   );

-  kernel::GemmComplex<
-    ElementA,
-    LayoutA,
-    ElementB,
-    LayoutB,
-    ElementC,
-    LayoutC,
-    ScalarType,
-    ComputeType,
-    ConvertOp,
-    InnerProductOp,
-    kMblock,
-    kNblock
-  ><<< grid, block >>>(
-    problem_size,
-    alpha,
-    tensor_a,
-    transform_a,
-    tensor_b,
-    transform_b,
-    beta,
-    tensor_c,
-    tensor_d,
-    initial_accum,
-    batch_count,
-    batch_stride_A,
-    batch_stride_B,
-    batch_stride_C,
-    batch_stride_D
-  );
+  if (grid.y <= std::numeric_limits<uint16_t>::max()) {
+    kernel::GemmComplex<
+      ElementA,
+      LayoutA,
+      ElementB,
+      LayoutB,
+      ElementC,
+      LayoutC,
+      ScalarType,
+      ComputeType,
+      ConvertOp,
+      InnerProductOp,
+      kMblock,
+      kNblock
+    ><<< grid, block >>>(
+      problem_size,
+      alpha,
+      tensor_a,
+      transform_a,
+      tensor_b,
+      transform_b,
+      beta,
+      tensor_c,
+      tensor_d,
+      initial_accum,
+      batch_count,
+      batch_stride_A,
+      batch_stride_B,
+      batch_stride_C,
+      batch_stride_D
+    );
+  } else {
+    // Using bigger thread tile size
+    int const kBigMblock = 4;
+    int const kBigNblock = 16;
+
+    dim3 Bigblock(16, 8);
+    dim3 Biggrid(
+      (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock),
+      (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock),
+      batch_count % std::numeric_limits<uint16_t>::max()
+    );
+
+    kernel::GemmComplex<
+      ElementA,
+      LayoutA,
+      ElementB,
+      LayoutB,
+      ElementC,
+      LayoutC,
+      ScalarType,
+      ComputeType,
+      ConvertOp,
+      InnerProductOp,
+      kBigMblock,
+      kBigNblock
+    ><<< Biggrid, Bigblock >>>(
+      problem_size,
+      alpha,
+      tensor_a,
+      transform_a,
+      tensor_b,
+      transform_b,
+      beta,
+      tensor_c,
+      tensor_d,
+      initial_accum,
+      batch_count,
+      batch_stride_A,
+      batch_stride_B,
+      batch_stride_C,
+      batch_stride_D
+    );
+  }
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

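The GemmComplex hunk above switches to a larger per-thread tile when the small tile would make `grid.y` exceed 65535 (the `uint16_t` maximum it tests against). The following is a minimal host-only sketch of that selection logic, not the CUTLASS implementation. `Dim3`, `LaunchConfig`, `tiles_needed`, and `choose_launch` are hypothetical names introduced for illustration, and the default 4 x 4 small tile is an assumption, since `kMblock` and `kNblock` are defined outside the hunk. The 16 x 8 block and the 4 x 16 fallback tile are taken from the hunk.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for CUDA's dim3 so the sketch builds without nvcc.
struct Dim3 {
  unsigned x;
  unsigned y;
  unsigned z;
};

// Hypothetical result type: launch shape plus the per-thread tile chosen.
struct LaunchConfig {
  Dim3 grid;
  Dim3 block;
  int m_block;  // elements per thread along M
  int n_block;  // elements per thread along N
};

// Number of tiles of size `tile` needed to cover `extent` (ceiling division).
inline unsigned tiles_needed(int extent, unsigned tile) {
  assert(extent > 0 && tile > 0);
  return (static_cast<unsigned>(extent) + tile - 1) / tile;
}

// Picks the small tile when grid.y fits, otherwise the 4 x 16 tile.
// The default m_block/n_block values are assumed, not taken from the diff.
inline LaunchConfig choose_launch(int m, int n, int batch_count,
                                  int m_block = 4, int n_block = 4) {
  unsigned const kMaxGridY = std::numeric_limits<std::uint16_t>::max();
  Dim3 const block{16, 8, 1};

  // grid.z mirrors the expression that appears in the hunk.
  unsigned const grid_z = static_cast<unsigned>(batch_count) % kMaxGridY;

  LaunchConfig config{
      {tiles_needed(m, block.x * m_block), tiles_needed(n, block.y * n_block), grid_z},
      block,
      m_block,
      n_block};

  if (config.grid.y > kMaxGridY) {
    // Fall back to the bigger thread tile so fewer blocks are needed along N.
    config.m_block = 4;
    config.n_block = 16;
    config.grid.x = tiles_needed(m, block.x * 4u);
    config.grid.y = tiles_needed(n, block.y * 16u);
  }
  return config;
}
```

With the assumed 4 x 4 small tile, each block covers 32 columns of the output, so the fallback is taken once `n` exceeds 32 * 65535. The sketch does not check whether the larger tile itself still fits, which matches the two-way branch in the hunk.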
@ -0,0 +1,355 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for complex-valued Rank 2K update in device-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace kernel {
|
||||
|
||||
/// Computes a rank-2k update (symmetric or Hermitian) among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// ComputeType(0) for this argument can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
int kMblock = 4,
|
||||
int kNblock = 4
|
||||
>
|
||||
__global__ void Rank2KComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum,
|
||||
FillMode fill_mode_c,
|
||||
BlasMode blas_mode,
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A = 0,
|
||||
int64_t batch_stride_B = 0,
|
||||
int64_t batch_stride_C = 0,
|
||||
int64_t batch_stride_D = 0) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
int const K = problem_size.k();
|
||||
|
||||
assert(M == N);
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
|
||||
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
||||
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
||||
int batch_idx = blockIdx.z;
|
||||
|
||||
tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
|
||||
tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
|
||||
tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
|
||||
tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
|
||||
|
||||
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
|
||||
|
||||
// Compute matrix product using blocks
|
||||
ComputeType accum[kMblock][kNblock];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N &&
|
||||
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col) )
|
||||
) {
|
||||
|
||||
// A x B^T (Symmetric) or A x B^H (Hermitian)
|
||||
// complex conjugation on operandB (b_t) is a function of the blas3 computation
|
||||
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
||||
ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
|
||||
conj(tensor_b.at(MatrixCoord(col, k_block))) :
|
||||
tensor_b.at(MatrixCoord(col, k_block));
|
||||
|
||||
ComputeType a_ik = ComputeType(a);
|
||||
ComputeType b_jk = ComputeType(b_t);
|
||||
|
||||
// complex conjugation is a function of operand layouts
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
a_ik = conj(a_ik);
|
||||
}
|
||||
// complex conjugation is a function of operand layouts
|
||||
if (transform_b == ComplexTransform::kConjugate) {
|
||||
b_jk = conj(b_jk);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
|
||||
|
||||
// B x A^T (Symmetric) or B x A^H (Hermitian)
|
||||
// complex conjugation on operandA (a_t) is a function of the blas3 computation
|
||||
ElementB b = tensor_b.at(MatrixCoord(row, k_block));
|
||||
ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
|
||||
conj(tensor_a.at(MatrixCoord(col, k_block))):
|
||||
tensor_a.at(MatrixCoord(col, k_block));
|
||||
|
||||
ComputeType b_ik = ComputeType(b);
|
||||
ComputeType a_jk = ComputeType(a_t);
|
||||
|
||||
// complex conjugation here is a function of operand layouts
|
||||
if (transform_b == ComplexTransform::kConjugate) {
|
||||
b_ik = conj(b_ik);
|
||||
}
|
||||
// complex conjugation here is a function of operand layouts
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
a_jk = conj(a_jk);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N &&
|
||||
((fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col))
|
||||
) {
|
||||
|
||||
ScalarType c = tensor_c.at(coord);
|
||||
// The imaginary parts of the diagonal elements of
|
||||
// a complex data type are assumed and set to zero
|
||||
if (blas_mode == BlasMode::kHermitian) {
|
||||
c = (row == col) ? real(c) : c;
|
||||
}
|
||||
|
||||
tensor_d.at(coord) = convert_op(
|
||||
alpha * ScalarType(accum[i][j]) +
|
||||
beta * c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
|
||||
tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
|
||||
tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
|
||||
tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
|
||||
|
||||
} // for (batch_idx)
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a rank-2k update (symmetric or Hermitian) among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// ComputeType(0) for this argument can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
||||
typename InnerProductOp = multiply_add<ComputeType>
|
||||
>
|
||||
void Rank2KComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum,
|
||||
FillMode fill_mode_c,
|
||||
BlasMode blas_mode,
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A = 0,
|
||||
int64_t batch_stride_B = 0,
|
||||
int64_t batch_stride_C = 0,
|
||||
int64_t batch_stride_D = 0) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
int const kMblock = 4;
|
||||
int const kNblock = 4;
|
||||
|
||||
dim3 block(16, 8);
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
||||
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
||||
batch_count < int(std::numeric_limits<uint16_t>::max()) ? batch_count : int(std::numeric_limits<uint16_t>::max())
|
||||
);
|
||||
|
||||
kernel::Rank2KComplex<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ScalarType,
|
||||
ComputeType,
|
||||
ConvertOp,
|
||||
InnerProductOp,
|
||||
kMblock,
|
||||
kNblock
|
||||
><<< grid, block >>>(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a,
|
||||
transform_a,
|
||||
tensor_b,
|
||||
transform_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
tensor_d,
|
||||
initial_accum,
|
||||
fill_mode_c,
|
||||
blas_mode,
|
||||
batch_count,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D
|
||||
);
|
||||
}
|
||||
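The launch arithmetic in the host wrapper above can be checked in isolation. The following is a minimal host-only sketch, assuming the same 16x8 thread block and 4x4 per-thread tile; `Grid2D`, `ceil_div`, and `rank2k_grid` are hypothetical names introduced for illustration and are not part of CUTLASS. Only the x and y grid dimensions are modeled here.

```cpp
#include <cassert>

// Hypothetical stand-in for the x/y components of CUDA's dim3.
struct Grid2D {
  unsigned x;
  unsigned y;
};

// Number of tiles of size `tile` needed to cover `extent` (rounds up).
inline unsigned ceil_div(unsigned extent, unsigned tile) {
  return (extent + tile - 1) / tile;
}

// Each thread computes a kMblock x kNblock patch of the output, so one
// block of block_x x block_y threads covers (block_x * kMblock) rows and
// (block_y * kNblock) columns.
inline Grid2D rank2k_grid(unsigned m, unsigned n,
                          unsigned block_x = 16, unsigned block_y = 8,
                          unsigned kMblock = 4, unsigned kNblock = 4) {
  Grid2D grid;
  grid.x = ceil_div(m, block_x * kMblock);
  grid.y = ceil_div(n, block_y * kNblock);
  return grid;
}
```

With the defaults, one block covers 64 rows and 32 columns, so a 100x100 problem needs a 2 x 4 grid. Threads whose row or column falls outside the problem are masked by the `row < M && col < N` check in the kernel.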
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// This assumes the accumulator type is the same type as the scalars.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType
|
||||
>
|
||||
void Rank2KComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
FillMode fill_mode_c,
|
||||
BlasMode blas_mode) {
|
||||
|
||||
Rank2KComplex(
|
||||
problem_size, alpha,
|
||||
tensor_a, transform_a,
|
||||
tensor_b, transform_b,
|
||||
beta, tensor_c, tensor_d,
|
||||
ScalarType(0),
|
||||
fill_mode_c,
|
||||
blas_mode);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
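For readers who want to check the arithmetic of the kernel above without a GPU, here is a minimal host-side sketch of the same update on one triangle. It is an illustration under stated assumptions, not the CUTLASS implementation: `Fill`, `Mode`, `cf`, and `rank2k_reference` are hypothetical names, A and B are assumed to be N x K row-major, C and D are N x N row-major, and the `transform_a` / `transform_b` conjugations are omitted. Like the kernel, it scales the whole accumulated sum by `alpha`.

```cpp
#include <cassert>
#include <complex>
#include <vector>

enum class Fill { Lower, Upper };          // hypothetical stand-in for cutlass::FillMode
enum class Mode { Symmetric, Hermitian };  // hypothetical stand-in for cutlass::BlasMode

using cf = std::complex<float>;

// D = alpha * (A * op(B) + B * op(A)) + beta * C on the selected triangle,
// where op is transpose (Symmetric) or conjugate transpose (Hermitian).
// Elements outside the triangle are left untouched.
inline void rank2k_reference(int N, int K, cf alpha,
                             const std::vector<cf>& A,
                             const std::vector<cf>& B,
                             cf beta,
                             const std::vector<cf>& C,
                             std::vector<cf>& D,
                             Fill fill, Mode mode) {
  for (int row = 0; row < N; ++row) {
    for (int col = 0; col < N; ++col) {
      bool in_triangle = (fill == Fill::Lower) ? (row >= col) : (row <= col);
      if (!in_triangle) {
        continue;
      }

      cf accum(0.0f, 0.0f);
      for (int k = 0; k < K; ++k) {
        cf a_ik = A[row * K + k];
        cf b_jk = B[col * K + k];
        cf b_ik = B[row * K + k];
        cf a_jk = A[col * K + k];
        if (mode == Mode::Hermitian) {
          b_jk = std::conj(b_jk);
          a_jk = std::conj(a_jk);
        }
        // First term: A x B^T (or B^H); second term: B x A^T (or A^H).
        accum += a_ik * b_jk + b_ik * a_jk;
      }

      cf c = C[row * N + col];
      // In Hermitian mode the diagonal of C is treated as real.
      if (mode == Mode::Hermitian && row == col) {
        c = cf(c.real(), 0.0f);
      }
      D[row * N + col] = alpha * accum + beta * c;
    }
  }
}
```

Note that the second product reads `b_ik` and `a_jk`, which is the pairing the comments in the kernel describe for the B x A^T (or B x A^H) term.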
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -50,6 +56,8 @@
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/blas3.h"
|
||||
|
||||
#include "cutlass/util/reference/device/tensor_foreach.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
|
||||
@ -1006,6 +1014,224 @@ struct TensorFillDiagonalFunc {
|
||||
}
|
||||
};
|
||||
|
||||
// Overwrites the elements of a tensor with a uniform value depending on fill mode
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
struct TensorFillPartialFunc {
|
||||
|
||||
/// View type
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
|
||||
/// Scalar type
|
||||
typedef typename TensorView::Element T;
|
||||
|
||||
/// Coordinate in tensor's index space
|
||||
typedef typename TensorView::TensorCoord TensorCoord;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
TensorView view;
|
||||
Element element;
|
||||
FillMode fill_mode;
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): fill_mode(FillMode::kNone) { }
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs the parameters from a view, fill element, and fill mode.
|
||||
Params(
|
||||
TensorView view_,
|
||||
Element element_,
|
||||
FillMode fill_mode_
|
||||
):
|
||||
view(view_), element(element_), fill_mode(fill_mode_) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TensorFillPartialFunc(Params const &params): params(params) {
|
||||
|
||||
}
|
||||
|
||||
/// Overwrites the element if it is within the covered region.
|
||||
CUTLASS_DEVICE
|
||||
void operator()(TensorCoord const &coord) {
|
||||
|
||||
bool predicate = true;
|
||||
|
||||
switch (params.fill_mode) {
|
||||
case FillMode::kFull:
|
||||
predicate = true;
|
||||
break;
|
||||
|
||||
case FillMode::kLower:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 1; i < Layout::kRank; ++i) {
|
||||
if (coord[i - 1] < coord[i]) {
|
||||
predicate = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case FillMode::kUpper:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 1; i < Layout::kRank; ++i) {
|
||||
if (coord[i - 1] > coord[i]) {
|
||||
predicate = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case FillMode::kDiagonal:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 1; i < Layout::kRank; ++i) {
|
||||
if (coord[i - 1] != coord[i]) {
|
||||
predicate = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case FillMode::kNone: // fall-through
|
||||
|
||||
default:
|
||||
predicate = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (predicate) {
|
||||
params.view.at(coord) = params.element;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
struct TensorClearPartialFunc {
|
||||
|
||||
/// View type
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
|
||||
/// Scalar type
|
||||
typedef typename TensorView::Element T;
|
||||
|
||||
/// Coordinate in tensor's index space
|
||||
typedef typename TensorView::TensorCoord TensorCoord;
|
||||
|
||||
///
|
||||
static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices");
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
TensorView view;
|
||||
Element element;
|
||||
FillMode fill_mode;
|
||||
int alignment;
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): fill_mode(FillMode::kNone) { }
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs the parameters from a view, fill element, fill mode, and alignment.
|
||||
Params(
|
||||
TensorView view_,
|
||||
Element element_,
|
||||
FillMode fill_mode_,
|
||||
int alignment_
|
||||
):
|
||||
view(view_), element(element_), fill_mode(fill_mode_), alignment(alignment_) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TensorClearPartialFunc(Params const &params): params(params) {
|
||||
|
||||
}
|
||||
|
||||
/// Overwrites the element if it is within the covered region.
|
||||
CUTLASS_DEVICE
|
||||
void operator()(TensorCoord const &coord) {
|
||||
|
||||
bool predicate = true;
|
||||
|
||||
switch (params.fill_mode) {
|
||||
|
||||
case FillMode::kLower:
|
||||
if ((coord[0] >= coord[1]) ||
|
||||
((coord[1] - coord[0]) >= params.alignment)) {
|
||||
predicate = false;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case FillMode::kUpper:
|
||||
if ((coord[0] <= coord[1]) ||
|
||||
((coord[0] - coord[1]) >= params.alignment)) {
|
||||
predicate = false;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case FillMode::kNone: // fall-through
|
||||
|
||||
default:
|
||||
predicate = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (predicate) {
|
||||
params.view.at(coord) = params.element;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
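For a rank-2 tensor, the coverage test in `TensorFillPartialFunc` reduces to a comparison between the row and column index. The sketch below restates that predicate on the host and applies it to a small row-major matrix; `FillModeSketch`, `is_covered`, and `fill_partial` are hypothetical names used only for illustration, and the enum is an assumed stand-in for `cutlass::FillMode`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for cutlass::FillMode.
enum class FillModeSketch { None, Lower, Upper, Diagonal, Full };

// True when (row, col) lies in the region selected by `mode`.
inline bool is_covered(FillModeSketch mode, int row, int col) {
  switch (mode) {
    case FillModeSketch::Full:     return true;
    case FillModeSketch::Lower:    return row >= col;  // on or below the diagonal
    case FillModeSketch::Upper:    return row <= col;  // on or above the diagonal
    case FillModeSketch::Diagonal: return row == col;
    case FillModeSketch::None:     return false;
  }
  return false;
}

// Writes `value` to every covered element of a rows x cols row-major matrix;
// elements outside the covered region keep their previous contents.
inline void fill_partial(std::vector<int>& matrix, int rows, int cols,
                         int value, FillModeSketch mode) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      if (is_covered(mode, r, c)) {
        matrix[r * cols + c] = value;
      }
    }
  }
}
```

Filling a zeroed 3x3 matrix with 7 in `Lower` mode writes the diagonal and everything below it, and leaves the strict upper triangle at zero.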
@ -1028,6 +1254,45 @@ void TensorFillDiagonal(
|
||||
);
|
||||
}
|
||||
|
||||
/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are
|
||||
/// not written.
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorFillPartial(
|
||||
TensorView<Element, Layout> view, ///< destination tensor
|
||||
Element element,
|
||||
FillMode fill_mode) {
|
||||
|
||||
typedef detail::TensorFillPartialFunc<Element, Layout> Func;
|
||||
typedef typename Func::Params Params;
|
||||
|
||||
TensorForEach<Func, Layout::kRank, Params>(
|
||||
view.extent(),
|
||||
Params(view, element, fill_mode)
|
||||
);
|
||||
}
|
||||
|
||||
/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong side
|
||||
/// of the fill mode (up to the alignment) are overwritten with the user-supplied element (typically zeros).
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorClearPartial(
|
||||
TensorView<Element, Layout> view, ///< destination tensor
|
||||
Element element,
|
||||
FillMode fill_mode,
|
||||
int alignment) {
|
||||
|
||||
typedef detail::TensorClearPartialFunc<Element, Layout> Func;
|
||||
typedef typename Func::Params Params;
|
||||
|
||||
TensorForEach<Func, Layout::kRank, Params>(
|
||||
view.extent(),
|
||||
Params(view, element, fill_mode, alignment)
|
||||
);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Fills a tensor with a uniform value
|
||||
|
||||
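The predicate in `TensorClearPartialFunc` selects a narrow band next to the diagonal rather than a whole triangle: for `kLower` it covers elements strictly above the diagonal whose column is fewer than `alignment` positions past the row, and for `kUpper` the mirror image below the diagonal. The following host-side sketch restates that rule; `ClearModeSketch`, `is_cleared`, and `clear_partial` are hypothetical names for illustration, and the matrix is assumed to be row-major.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the fill modes handled by the clear functor.
enum class ClearModeSketch { None, Lower, Upper };

// True when (row, col) is on the wrong side of the diagonal for `mode`
// and within `alignment` elements of it.
inline bool is_cleared(ClearModeSketch mode, int row, int col, int alignment) {
  switch (mode) {
    case ClearModeSketch::Lower:
      return row < col && (col - row) < alignment;
    case ClearModeSketch::Upper:
      return row > col && (row - col) < alignment;
    case ClearModeSketch::None:
      return false;
  }
  return false;
}

// Overwrites the selected band of a rows x cols row-major matrix with `value`.
inline void clear_partial(std::vector<int>& matrix, int rows, int cols,
                          int value, ClearModeSketch mode, int alignment) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      if (is_cleared(mode, r, c, alignment)) {
        matrix[r * cols + c] = value;
      }
    }
  }
}
```

With a 4x4 matrix of ones, `Lower` mode, and an alignment of 2, only the first superdiagonal is zeroed; elements two or more columns to the right of the diagonal are left as they were.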
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,25 +1,31 @@
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
261
tools/util/include/cutlass/util/reference/host/rank_2k.h
Normal file
@ -0,0 +1,261 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for Rank 2k update in host-side code.
|
||||
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a symmetric rank-2k update, D = alpha * (A * B^T + B * A^T) + beta * C, on the
|
||||
/// triangle selected by FillModeC. Operands are rank-2 tensors pointed to by TensorRef objects.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
FillMode FillModeC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
||||
>
|
||||
void compute_rank2k(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
static_assert(
|
||||
FillModeC == FillMode::kLower ||
|
||||
FillModeC == FillMode::kUpper,
|
||||
"Fill Mode can either be Lower or Upper.");
|
||||
|
||||
using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower),
|
||||
std::greater_equal<int>,
|
||||
std::less_equal<int>>::type;
|
||||
|
||||
// Note: batch is ignored.
|
||||
// Note: M is same as N for Rank 2k update
|
||||
int const N = problem_size.n();
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Blocking is necessary to speed up the reference implementation
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
CompareOp compare_op;
|
||||
|
||||
for (int row_block = 0; row_block < N; row_block += Nblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComputeType accum[Nblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Nblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Nblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < N && col < N && compare_op(row, col))
|
||||
{
|
||||
|
||||
// A x B^T
|
||||
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
||||
ElementB b_t = tensor_b.at(MatrixCoord(col, k_block));
|
||||
|
||||
ComputeType compute_a(cast_if_scalar<ComputeType>(a));
|
||||
ComputeType compute_b_t(cast_if_scalar<ComputeType>(b_t));
|
||||
|
||||
accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]);
|
||||
|
||||
// B x A^T
|
||||
ElementB b = tensor_b.at(MatrixCoord(row, k_block));
|
||||
ElementA a_t = tensor_a.at(MatrixCoord(col, k_block));
|
||||
|
||||
ComputeType compute_b(cast_if_scalar<ComputeType>(b));
|
||||
ComputeType compute_a_t(cast_if_scalar<ComputeType>(a_t));
|
||||
|
||||
accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Nblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < N && col < N &&
|
||||
( (FillModeC == FillMode::kLower && row >= col) ||
|
||||
(FillModeC == FillMode::kUpper && row <= col) )
|
||||
) {
|
||||
tensor_d.at(coord) = convert_op(
|
||||
alpha * ScalarType(accum[i][j]) +
|
||||
beta * ScalarType(tensor_c.at(coord)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
FillMode FillModeC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
||||
>
|
||||
void compute_rank2k(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
ComputeType initial_accum) {
|
||||
compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
|
||||
ScalarType, ComputeType, InnerProductOp, ConvertOp>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
|
||||
initial_accum);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
FillMode FillModeC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = cutlass::arch::OpMultiplyAdd
|
||||
>
|
||||
struct Rank2K;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for multiply-add
|
||||
template <typename ElementA, typename LayoutA,
|
||||
typename ElementB, typename LayoutB,
|
||||
typename ElementC, typename LayoutC, FillMode FillModeC,
|
||||
typename ScalarType, typename ComputeType>
|
||||
struct Rank2K<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC, ScalarType,
|
||||
ComputeType, arch::OpMultiplyAdd> {
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
|
||||
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
||||
}
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
|
||||
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
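The rank-2k reference above can be read as a triangular update, D = alpha * (A x B^T + B x A^T) + beta * C, written only to the triangle of D selected by the fill mode. The following is a minimal, unblocked sketch of that idea using plain row-major `std::vector` storage. It is not the CUTLASS API: the names `Fill` and `rank2k_naive` are hypothetical, and `float` is assumed for every element, scalar, and accumulator type.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for cutlass::FillMode; not the CUTLASS type.
enum class Fill { Lower, Upper };

// Computes D = alpha * (A * B^T + B * A^T) + beta * C on one triangle of D.
// A and B are n x k, C and D are n x n, all stored row-major.
// Entries of D outside the selected triangle are left untouched.
inline void rank2k_naive(int n, int k, float alpha,
                         const std::vector<float>& a,
                         const std::vector<float>& b,
                         float beta,
                         const std::vector<float>& c,
                         std::vector<float>& d,
                         Fill fill) {
  for (int row = 0; row < n; ++row) {
    for (int col = 0; col < n; ++col) {
      // Same predicate the reference expresses with greater_equal / less_equal.
      bool in_triangle = (fill == Fill::Lower) ? (row >= col) : (row <= col);
      if (!in_triangle) {
        continue;
      }
      float accum = 0.0f;
      for (int p = 0; p < k; ++p) {
        accum += a[row * k + p] * b[col * k + p];  // A x B^T
        accum += b[row * k + p] * a[col * k + p];  // B x A^T
      }
      d[row * n + col] = alpha * accum + beta * c[row * n + col];
    }
  }
}
```

Because only one triangle is written, a caller comparing against this sketch should compare only that triangle, which is also how the reference above behaves.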
|
||||
318
tools/util/include/cutlass/util/reference/host/rank_2k_complex.h
Normal file
@ -0,0 +1,318 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for complex-valued Rank 2K update in host-side code.
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include <assert.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued Rank 2K update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// ComputeType(0) for it can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
||||
typename InnerProductOp = multiply_add<ComputeType>
|
||||
>
|
||||
void Rank2KComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum,
|
||||
FillMode fill_mode_c,
|
||||
BlasMode blas_mode,
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A = 0,
|
||||
int64_t batch_stride_B = 0,
|
||||
int64_t batch_stride_C = 0,
|
||||
int64_t batch_stride_D = 0) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
// Note: batches are handled by the batch_count loop below.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Rank2K update operates on A=NxK, B=NxK, and C=NxN
|
||||
assert(M==N);
|
||||
|
||||
// Blocking is necessary to speed up the reference implementation
|
||||
int const Mblock = 16;
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
|
||||
for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
|
||||
|
||||
// Compute matrix product using blocks
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComputeType accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N &&
|
||||
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col) )
|
||||
) {
|
||||
|
||||
// A x B^T (Symmetric) or A x B^H (Hermitian)
|
||||
// complex conjugation on operandB (b_t) is a function of the blas3 computation
|
||||
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
||||
ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
|
||||
conj(tensor_b.at(MatrixCoord(col, k_block))) :
|
||||
tensor_b.at(MatrixCoord(col, k_block));
|
||||
|
||||
ComputeType a_ik = ComputeType(a);
|
||||
ComputeType b_jk = ComputeType(b_t);
|
||||
|
||||
// complex conjugation is a function of operand layouts
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
a_ik = conj(a_ik);
|
||||
}
|
||||
// complex conjugation is a function of operand layouts
|
||||
if (transform_b == ComplexTransform::kConjugate) {
|
||||
b_jk = conj(b_jk);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* HER2K needs two epilogues to handle a complex alpha value */
|
||||
if ( blas_mode == BlasMode::kHermitian ) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N &&
|
||||
((fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col))
|
||||
) {
|
||||
|
||||
ScalarType c = tensor_c.at(coord);
|
||||
// The imaginary parts of the diagonal elements of
|
||||
// a Hermitian matrix are assumed to be zero, so they are set to zero
|
||||
if (blas_mode == BlasMode::kHermitian) {
|
||||
c = (row == col) ? real(c) : c;
|
||||
}
|
||||
|
||||
tensor_d.at(coord) = convert_op(alpha *
|
||||
ScalarType(accum[i][j]) +
|
||||
beta * c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Zeroing out accum (resetting it to initial_accum) for the second HERK */
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N &&
|
||||
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col) )
|
||||
) {
|
||||
|
||||
// B x A^T (Symmetric) or B x A^H (Hermitian)
|
||||
// complex conjugation on operandB (a_t) is a function of the blas3 computation
|
||||
ElementB b = tensor_b.at(MatrixCoord(row, k_block));
|
||||
ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
|
||||
conj(tensor_a.at(MatrixCoord(col, k_block))):
|
||||
tensor_a.at(MatrixCoord(col, k_block));
|
||||
|
||||
ComputeType b_ik = ComputeType(b);
|
||||
ComputeType a_jk = ComputeType(a_t);
|
||||
|
||||
// complex conjugation here is a function of operand layouts
|
||||
if (transform_b == ComplexTransform::kConjugate) {
|
||||
b_ik = conj(b_ik);
|
||||
}
|
||||
// complex conjugation here is a function of operand layouts
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
a_jk = conj(a_jk);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ?
|
||||
conj(alpha) : alpha;
|
||||
ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ?
|
||||
1 : beta;
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N &&
|
||||
((fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col))
|
||||
) {
|
||||
|
||||
ScalarType d = (blas_mode == BlasMode::kHermitian) ?
|
||||
tensor_d.at(coord) : tensor_c.at(coord);
|
||||
|
||||
ScalarType tmp_d = convert_op(
|
||||
alpha_hermitian * ScalarType(accum[i][j]) +
|
||||
beta_hermitian * d);
|
||||
|
||||
if (blas_mode == BlasMode::kHermitian && row == col ) {
|
||||
tensor_d.at(coord) = real(tmp_d);
|
||||
} else {
|
||||
tensor_d.at(coord) = tmp_d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // for (col_block)
|
||||
} // for (row_block)
|
||||
|
||||
tensor_a.add_pointer_offset(batch_stride_A);
|
||||
tensor_b.add_pointer_offset(batch_stride_B);
|
||||
tensor_c.add_pointer_offset(batch_stride_C);
|
||||
tensor_d.add_pointer_offset(batch_stride_D);
|
||||
|
||||
} // for (batch_idx)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued Rank 2K update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// This assumes the accumulator type is the same type as the scalars.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType
|
||||
>
|
||||
void Rank2KComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
FillMode fill_mode_c,
|
||||
BlasMode blas_mode) {
|
||||
|
||||
Rank2KComplex(
|
||||
problem_size, alpha,
|
||||
tensor_a, transform_a,
|
||||
tensor_b, transform_b,
|
||||
beta, tensor_c, tensor_d,
|
||||
ScalarType(0),
|
||||
fill_mode_c,
|
||||
blas_mode);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
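The Hermitian branch above applies alpha to A x B^H in a first epilogue and conj(alpha) to B x A^H in a second one, and forces the diagonal of the result to be real. A minimal sketch of that two-step structure with `std::complex<float>` follows. The names `cf` and `her2k_lower_naive` are hypothetical, only the lower triangle is handled, beta is assumed to be real (as in BLAS HER2K), and no conjugating operand transforms are modelled.

```cpp
#include <cassert>
#include <complex>
#include <vector>

using cf = std::complex<float>;

// Lower-triangular HER2K sketch:
//   D = alpha * A * B^H + conj(alpha) * B * A^H + beta * C
// A and B are n x k, C and D are n x n, all stored row-major.
// The strictly upper triangle of D is left untouched.
inline void her2k_lower_naive(int n, int k, cf alpha,
                              const std::vector<cf>& a,
                              const std::vector<cf>& b,
                              float beta,
                              const std::vector<cf>& c,
                              std::vector<cf>& d) {
  for (int row = 0; row < n; ++row) {
    for (int col = 0; col <= row; ++col) {
      cf ab(0.0f, 0.0f);  // accumulates A x B^H
      cf ba(0.0f, 0.0f);  // accumulates B x A^H
      for (int p = 0; p < k; ++p) {
        ab += a[row * k + p] * std::conj(b[col * k + p]);
        ba += b[row * k + p] * std::conj(a[col * k + p]);
      }
      cf c_in = c[row * n + col];
      if (row == col) {
        c_in = cf(c_in.real(), 0.0f);  // diagonal of a Hermitian matrix is real
      }
      // First epilogue: alpha scales A x B^H and beta scales C.
      cf out = alpha * ab + beta * c_in;
      // Second epilogue: conj(alpha) scales B x A^H, added with unit weight.
      out = std::conj(alpha) * ba + out;
      if (row == col) {
        out = cf(out.real(), 0.0f);
      }
      d[row * n + col] = out;
    }
  }
}
```

Splitting the update into two epilogues is what lets each product carry its own scalar (alpha versus its conjugate), while beta is applied to C only once.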
|
||||
234
tools/util/include/cutlass/util/reference/host/rank_k_complex.h
Normal file
@ -0,0 +1,234 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for complex-valued Rank K update in host-side code.
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include <assert.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued Rank K update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// ComputeType(0) for it can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
||||
typename InnerProductOp = multiply_add<ComputeType>
|
||||
>
|
||||
void Rank2KComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum,
|
||||
FillMode fill_mode_c,
|
||||
BlasMode blas_mode,
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A = 0,
|
||||
int64_t batch_stride_C = 0,
|
||||
int64_t batch_stride_D = 0) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
// Note: batches are handled by the batch_count loop below.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Rank K update operates on A=NxK and C=NxN
|
||||
assert(M==N);
|
||||
|
||||
// Blocking is necessary to speed up the reference implementation
|
||||
int const Mblock = 16;
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
|
||||
for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
|
||||
|
||||
// Compute matrix product using blocks
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComputeType accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N &&
|
||||
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col) )
|
||||
) {
|
||||
|
||||
// A x A^T (Symmetric) or A x A^H (Hermitian)
|
||||
// complex conjugation on operandB (a_t) (function of blas3 computation)
|
||||
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
||||
ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
|
||||
conj(tensor_a.at(MatrixCoord(col, k_block))) :
|
||||
tensor_a.at(MatrixCoord(col, k_block));
|
||||
|
||||
ComputeType a_ik = ComputeType(a);
|
||||
ComputeType b_jk = ComputeType(a_t);
|
||||
|
||||
// complex conjugation (function of input layouts)
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
a_ik = conj(a_ik);
|
||||
}
|
||||
// complex conjugation (function of input layouts)
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
b_jk = conj(b_jk);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N &&
|
||||
((fill_mode_c == FillMode::kLower && row >= col) ||
|
||||
(fill_mode_c == FillMode::kUpper && row <= col))
|
||||
) {
|
||||
|
||||
ScalarType c = tensor_c.at(coord);
|
||||
// The imaginary parts of the diagonal elements of
|
||||
// a Hermitian matrix are assumed to be zero, so they are set to zero
|
||||
if (blas_mode == BlasMode::kHermitian) {
|
||||
c = (row == col) ? real(c) : c;
|
||||
}
|
||||
|
||||
ScalarType tmp_d = convert_op(
|
||||
alpha * ScalarType(accum[i][j]) +
|
||||
beta * c);
|
||||
|
||||
if (blas_mode == BlasMode::kHermitian && row == col ) {
|
||||
tensor_d.at(coord) = real(tmp_d);
|
||||
} else {
|
||||
tensor_d.at(coord) = tmp_d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // for (col_block)
|
||||
} // for (row_block)
|
||||
|
||||
tensor_a.add_pointer_offset(batch_stride_A);
|
||||
tensor_c.add_pointer_offset(batch_stride_C);
|
||||
tensor_d.add_pointer_offset(batch_stride_D);
|
||||
|
||||
} // for (batch_idx)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued Rank K update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// This assumes the accumulator type is the same type as the scalars.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType
|
||||
>
|
||||
void RankKComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
FillMode fill_mode_c,
|
||||
BlasMode blas_mode) {
|
||||
|
||||
Rank2KComplex(
|
||||
problem_size, alpha,
|
||||
tensor_a, transform_a,
|
||||
beta, tensor_c, tensor_d,
|
||||
ScalarType(0),
|
||||
fill_mode_c,
|
||||
blas_mode);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
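The rank-k reference above differs between its symmetric and Hermitian modes only in whether the second operand is conjugated and whether the diagonal is forced to be real. A minimal sketch of that switch is shown here, assuming `std::complex<float>` data and lower-triangular output. The names `Mode`, `cf`, and `rankk_lower_naive` are hypothetical and are not part of CUTLASS.

```cpp
#include <cassert>
#include <complex>
#include <vector>

using cf = std::complex<float>;

// Hypothetical stand-in for cutlass::BlasMode; not the CUTLASS type.
enum class Mode { Symmetric, Hermitian };

// Lower-triangular rank-k sketch:
//   Symmetric: D = alpha * A * A^T + beta * C
//   Hermitian: D = alpha * A * A^H + beta * C, with a real diagonal
// A is n x k, C and D are n x n, all stored row-major.
inline void rankk_lower_naive(int n, int k, cf alpha,
                              const std::vector<cf>& a,
                              cf beta,
                              const std::vector<cf>& c,
                              std::vector<cf>& d,
                              Mode mode) {
  bool hermitian = (mode == Mode::Hermitian);
  for (int row = 0; row < n; ++row) {
    for (int col = 0; col <= row; ++col) {
      cf accum(0.0f, 0.0f);
      for (int p = 0; p < k; ++p) {
        cf a_t = a[col * k + p];
        if (hermitian) {
          a_t = std::conj(a_t);  // A^H instead of A^T
        }
        accum += a[row * k + p] * a_t;
      }
      cf c_in = c[row * n + col];
      if (hermitian && row == col) {
        c_in = cf(c_in.real(), 0.0f);
      }
      cf out = alpha * accum + beta * c_in;
      if (hermitian && row == col) {
        out = cf(out.real(), 0.0f);
      }
      d[row * n + col] = out;
    }
  }
}
```

Running both modes on the same input is a quick way to see the difference: the symmetric result can have a complex diagonal, while the Hermitian result cannot.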
|
||||
285
tools/util/include/cutlass/util/reference/host/symm.h
Normal file
@ -0,0 +1,285 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for SYMM update in host-side code.
|
||||
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
||||
>
|
||||
void compute_symm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
static_assert(SideModeA != SideMode::kInvalid
|
||||
, "Side Mode can either be Left or Right.");
|
||||
|
||||
static_assert(
|
||||
FillModeA == FillMode::kLower ||
|
||||
FillModeA == FillMode::kUpper,
|
||||
"Fill Mode can either be Lower or Upper.");
|
||||
|
||||
using CompareOp_w_diag = typename TrMatrixCompareOp<FillModeA, DiagType::kNonUnit>::Type;
|
||||
using CompareOp_wo_diag = typename TrMatrixCompareOp<FillModeA, DiagType::kZero>::Type;
|
||||
|
||||
// Note: batch is ignored.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
// Assuming correct k-dimension value is passed
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Blocking is necessary to speed up the reference implementation
|
||||
int const Mblock = 16;
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
CompareOp_w_diag compare_op_1;
|
||||
CompareOp_wo_diag compare_op_2;
|
||||
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComputeType accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N) {
|
||||
ElementA a_1 = ElementA();
|
||||
ElementB b_1 = ElementB();
|
||||
ElementA a_2 = ElementA();
|
||||
ElementB b_2 = ElementB();
|
||||
|
||||
// A x B or B x A (with diagonal)
|
||||
if (SideModeA == SideMode::kLeft) {
|
||||
a_1 = (compare_op_1(row, k_block)) ?
|
||||
(tensor_a.at(MatrixCoord(row, k_block))) : ElementA();
|
||||
b_1 = tensor_b.at(MatrixCoord(k_block, col));
|
||||
} else if (SideModeA == SideMode::kRight) {
|
||||
a_1 = tensor_b.at(MatrixCoord(row, k_block));
|
||||
b_1 = (compare_op_1(k_block, col)) ?
|
||||
tensor_a.at(MatrixCoord(k_block, col)) : ElementA();
|
||||
}
|
||||
|
||||
ComputeType compute_a_1(cast_if_scalar<ComputeType>(a_1));
|
||||
ComputeType compute_b_1(cast_if_scalar<ComputeType>(b_1));
|
||||
|
||||
accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]);
|
||||
|
||||
// A^T x B or B x A^T (without diagonal)
|
||||
if (SideModeA == SideMode::kLeft) {
|
||||
a_2 = (compare_op_2(k_block, row)) ?
|
||||
(tensor_a.at(MatrixCoord(k_block, row))) : ElementA();
|
||||
b_2 = tensor_b.at(MatrixCoord(k_block, col));
|
||||
} else if (SideModeA == SideMode::kRight) {
|
||||
a_2 = tensor_b.at(MatrixCoord(row, k_block));
|
||||
b_2 = (compare_op_2(col, k_block)) ?
|
||||
tensor_a.at(MatrixCoord(col, k_block)) : ElementA();
|
||||
}
|
||||
|
||||
ComputeType compute_a_2(cast_if_scalar<ComputeType>(a_2));
|
||||
ComputeType compute_b_2(cast_if_scalar<ComputeType>(b_2));
|
||||
|
||||
accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N) {
|
||||
tensor_d.at(coord) = convert_op(
|
||||
alpha * ScalarType(accum[i][j]) +
|
||||
beta * ScalarType(tensor_c.at(coord)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
||||
>
|
||||
void compute_symm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
ComputeType initial_accum) {
|
||||
compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
ScalarType, ComputeType, InnerProductOp, ConvertOp>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
|
||||
initial_accum);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = cutlass::arch::OpMultiplyAdd
|
||||
>
|
||||
struct Symm;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for multiply-add
|
||||
template <typename ElementA, typename LayoutA,
|
||||
SideMode SideModeA, FillMode FillModeA,
|
||||
typename ElementB, typename LayoutB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename ScalarType, typename ComputeType>
|
||||
struct Symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
||||
ComputeType, arch::OpMultiplyAdd> {
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
||||
}
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
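The `compute_symm` loop above accumulates two terms per `k`: the stored triangle of A including the diagonal, and the mirrored triangle excluding it. The following is a minimal standalone sketch of that idea for the left-side, lower-fill case only. It assumes row-major `std::vector<double>` storage and omits the 16x16 blocking; `symm_left_lower` is a hypothetical helper written for illustration, not a CUTLASS function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of CUTLASS): D = alpha * sym(A) * B + beta * C.
// A is M x M, B/C/D are M x N, all row-major. Only the lower triangle of A
// (row >= col) is ever read; the upper triangle is reconstructed by mirroring.
std::vector<double> symm_left_lower(int M, int N, double alpha,
                                    std::vector<double> const &A,
                                    std::vector<double> const &B, double beta,
                                    std::vector<double> const &C) {
  assert(A.size() == static_cast<std::size_t>(M) * M);
  assert(B.size() == static_cast<std::size_t>(M) * N);
  assert(C.size() == static_cast<std::size_t>(M) * N);

  std::vector<double> D(static_cast<std::size_t>(M) * N);
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      double accum = 0.0;
      for (int k = 0; k < M; ++k) {
        // Term 1: stored triangle, diagonal included (row >= k).
        double a_1 = (row >= k) ? A[row * M + k] : 0.0;
        // Term 2: mirrored triangle, diagonal excluded (k > row).
        double a_2 = (k > row) ? A[k * M + row] : 0.0;
        accum += a_1 * B[k * N + col];
        accum += a_2 * B[k * N + col];
      }
      D[row * N + col] = alpha * accum + beta * C[row * N + col];
    }
  }
  return D;
}
```

Calling it with a 2x2 `A` whose upper entry holds a sentinel value and an identity `B` returns the mirrored symmetric matrix, which shows that the unstored triangle is never read.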
319
tools/util/include/cutlass/util/reference/host/symm_complex.h
Normal file
@ -0,0 +1,319 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for complex-valued SYMM update in host-side code.
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include <assert.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued SYMM or HEMM update (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
BlasMode BlasMode_ = BlasMode::kSymmetric,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
||||
>
|
||||
void compute_symm_complex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum,
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A = 0,
|
||||
int64_t batch_stride_B = 0,
|
||||
int64_t batch_stride_C = 0,
|
||||
int64_t batch_stride_D = 0) {
|
||||
|
||||
static SideMode const kSideModeA = SideModeA;
|
||||
static FillMode const kFillModeA = FillModeA;
|
||||
static BlasMode const kBlasMode = BlasMode_;
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
static_assert(kSideModeA != SideMode::kInvalid
|
||||
, "Side Mode can either be Left or Right.");
|
||||
|
||||
static_assert(
|
||||
kFillModeA == FillMode::kLower ||
|
||||
kFillModeA == FillMode::kUpper,
|
||||
"Fill Mode can either be Lower or Upper.");
|
||||
|
||||
using CompareOp_w_diag = typename TrMatrixCompareOp<kFillModeA, DiagType::kNonUnit>::Type;
|
||||
using CompareOp_wo_diag = typename TrMatrixCompareOp<kFillModeA, DiagType::kZero>::Type;
|
||||
|
||||
// Note: batches are processed one at a time by the batch_idx loop below.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
// Assuming correct k-dimension value is passed
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Blocking necessary to speed up reference implementation
|
||||
int const Mblock = 16;
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
CompareOp_w_diag compare_op_1;
|
||||
CompareOp_wo_diag compare_op_2;
|
||||
|
||||
for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
|
||||
|
||||
// Compute matrix product using blocks
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComputeType accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N)
|
||||
{
|
||||
ElementA a_1 = ElementA();
|
||||
ElementB b_1 = ElementB();
|
||||
ElementA a_2 = ElementA();
|
||||
ElementB b_2 = ElementB();
|
||||
|
||||
// A x B or B x A (with diagonal)
|
||||
if (kSideModeA == SideMode::kLeft) {
|
||||
a_1 = (compare_op_1(row, k_block)) ?
|
||||
(tensor_a.at(MatrixCoord(row, k_block))) : ElementA();
|
||||
b_1 = tensor_b.at(MatrixCoord(k_block, col));
|
||||
} else if (kSideModeA == SideMode::kRight) {
|
||||
a_1 = tensor_b.at(MatrixCoord(row, k_block));
|
||||
b_1 = (compare_op_1(k_block, col)) ?
|
||||
tensor_a.at(MatrixCoord(k_block, col)) : ElementA();
|
||||
}
|
||||
ComputeType compute_a_1 = ComputeType(a_1);
|
||||
ComputeType compute_b_1 = ComputeType(b_1);
|
||||
|
||||
// The imaginary parts of the diagonal elements of
|
||||
// a complex data type are assumed and set to zero
|
||||
if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) {
|
||||
compute_a_1 = real(compute_a_1);
|
||||
} else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) {
|
||||
compute_b_1 = real(compute_b_1);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]);
|
||||
|
||||
// A^T x B or B x A^T (without diagonal)
|
||||
if (kSideModeA == SideMode::kLeft) {
|
||||
a_2 = (compare_op_2(k_block, row)) ?
|
||||
(tensor_a.at(MatrixCoord(k_block, row))) : ElementA();
|
||||
b_2 = tensor_b.at(MatrixCoord(k_block, col));
|
||||
if (kBlasMode == BlasMode::kHermitian)
|
||||
a_2 = conj(a_2);
|
||||
} else if (kSideModeA == SideMode::kRight) {
|
||||
a_2 = tensor_b.at(MatrixCoord(row, k_block));
|
||||
b_2 = (compare_op_2(col, k_block)) ?
|
||||
tensor_a.at(MatrixCoord(col, k_block)) : ElementA();
|
||||
if (kBlasMode == BlasMode::kHermitian)
|
||||
b_2 = conj(b_2);
|
||||
}
|
||||
|
||||
ComputeType compute_a_2 = ComputeType(a_2);
|
||||
ComputeType compute_b_2 = ComputeType(b_2);
|
||||
|
||||
accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N) {
|
||||
|
||||
ScalarType c = tensor_c.at(coord);
|
||||
|
||||
tensor_d.at(coord) = convert_op(
|
||||
alpha * ScalarType(accum[i][j]) +
|
||||
beta * c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // for (col_block)
|
||||
} // for (row_block)
|
||||
|
||||
tensor_a.add_pointer_offset(batch_stride_A);
|
||||
tensor_b.add_pointer_offset(batch_stride_B);
|
||||
tensor_c.add_pointer_offset(batch_stride_C);
|
||||
tensor_d.add_pointer_offset(batch_stride_D);
|
||||
|
||||
} // for (batch_idx)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric,
|
||||
typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex
|
||||
>
|
||||
struct SymmComplex;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for complex multiply-add
|
||||
template <typename ElementA, typename LayoutA,
|
||||
SideMode SideModeA, FillMode FillModeA,
|
||||
typename ElementB, typename LayoutB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename ScalarType, typename ComputeType,
|
||||
BlasMode BlasMode_>
|
||||
struct SymmComplex<ElementA, LayoutA,
|
||||
SideModeA, FillModeA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC, ScalarType,
|
||||
ComputeType, BlasMode_,
|
||||
arch::OpMultiplyAddComplex> {
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_symm_complex<ElementA, LayoutA,
|
||||
SideModeA, FillModeA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ScalarType, ComputeType, BlasMode_, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for gaussian multiply-add
|
||||
template <typename ElementA, typename LayoutA,
|
||||
SideMode SideModeA, FillMode FillModeA,
|
||||
typename ElementB, typename LayoutB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename ScalarType, typename ComputeType,
|
||||
BlasMode BlasMode_>
|
||||
struct SymmComplex<ElementA, LayoutA,
|
||||
SideModeA, FillModeA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC, ScalarType,
|
||||
ComputeType, BlasMode_,
|
||||
arch::OpMultiplyAddGaussianComplex> {
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_symm_complex<ElementA, LayoutA,
|
||||
SideModeA, FillModeA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ScalarType, ComputeType, BlasMode_, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
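In `BlasMode::kHermitian`, `compute_symm_complex` differs from the symmetric case in two places: the imaginary part of each diagonal element of A is treated as zero, and the mirrored element is conjugated. The sketch below illustrates both rules for the left-side, lower-fill case. It assumes row-major `std::vector<std::complex<double>>` storage; `hemm_left_lower` and the `cplx` alias are hypothetical names introduced here, not CUTLASS APIs.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

using cplx = std::complex<double>;

// Hypothetical helper (not part of CUTLASS): D = alpha * herm(A) * B + beta * C.
// A is M x M, B/C/D are M x N, all row-major. Only the lower triangle of A is read.
std::vector<cplx> hemm_left_lower(int M, int N, cplx alpha,
                                  std::vector<cplx> const &A,
                                  std::vector<cplx> const &B, cplx beta,
                                  std::vector<cplx> const &C) {
  assert(A.size() == static_cast<std::size_t>(M) * M);
  assert(B.size() == static_cast<std::size_t>(M) * N);
  assert(C.size() == static_cast<std::size_t>(M) * N);

  std::vector<cplx> D(static_cast<std::size_t>(M) * N);
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      cplx accum(0.0, 0.0);
      for (int k = 0; k < M; ++k) {
        // Stored triangle, diagonal included.
        cplx a_1 = (row >= k) ? A[row * M + k] : cplx(0.0, 0.0);
        // Diagonal of a Hermitian matrix is real: drop any stored imaginary part.
        if (row == k) {
          a_1 = cplx(a_1.real(), 0.0);
        }
        // Mirrored triangle, diagonal excluded, conjugated.
        cplx a_2 = (k > row) ? std::conj(A[k * M + row]) : cplx(0.0, 0.0);
        accum += a_1 * B[k * N + col];
        accum += a_2 * B[k * N + col];
      }
      D[row * N + col] = alpha * accum + beta * C[row * N + col];
    }
  }
  return D;
}
```

With an identity `B`, `alpha = 1`, and `beta = 0`, the result is the full Hermitian matrix implied by the stored lower triangle, so the diagonal comes back purely real and the upper entries are conjugates of the lower ones.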
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -42,6 +48,8 @@
|
||||
#include "cutlass/subbyte_reference.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/tensor_view_planar_complex.h"
|
||||
#include "cutlass/blas3.h"
|
||||
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "tensor_foreach.h"
|
||||
|
||||
@ -303,6 +311,51 @@ struct TensorFillGaussianFunc {
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes a random Gaussian distribution
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
struct TensorFillSymmetricGaussianFunc {
|
||||
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
TensorView view;
|
||||
RandomGaussianFunc<Element> func;
|
||||
cutlass::FillMode fill_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Construction of Gaussian RNG functor.
|
||||
TensorFillSymmetricGaussianFunc(
|
||||
TensorView view_ = TensorView(),
|
||||
RandomGaussianFunc<Element> func_ = RandomGaussianFunc<Element>(),
|
||||
cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid
|
||||
):
|
||||
view(view_), func(func_), fill_mode(fill_mode_) {
|
||||
|
||||
}
|
||||
|
||||
/// Compute random value and update RNG state
|
||||
void operator()(Coord<Layout::kRank> const &coord) const {
|
||||
// Fill half of matrix based on FillMode
|
||||
if (Layout::kRank == 2 &&
|
||||
fill_mode == cutlass::FillMode::kLower &&
|
||||
coord[0] >= coord[1]) {
|
||||
view.at(coord) = func();
|
||||
} else if (Layout::kRank == 2 &&
|
||||
fill_mode == cutlass::FillMode::kUpper &&
|
||||
coord[0] <= coord[1]) {
|
||||
view.at(coord) = func();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -351,6 +404,36 @@ void TensorFillRandomGaussian(
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Fills a tensor with random values with a Gaussian distribution.
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorFillSymmetricRandomGaussian(
|
||||
TensorView<Element, Layout> dst, ///< destination tensor
|
||||
uint64_t seed, ///< seed for RNG
|
||||
cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
|
||||
double mean = 0, ///< Gaussian distribution's mean
|
||||
double stddev = 1, ///< Gaussian distribution's standard deviation
|
||||
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
|
||||
detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits);
|
||||
|
||||
detail::TensorFillSymmetricGaussianFunc<Element, Layout> func(
|
||||
dst,
|
||||
random_func,
|
||||
fill_mode
|
||||
);
|
||||
|
||||
TensorForEach(
|
||||
dst.extent(),
|
||||
func
|
||||
);
|
||||
}
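The rank-2 branches of the symmetric fill functors above write only one triangle of the matrix, diagonal included. A minimal standalone sketch of that selection rule follows; `FillModeSketch`, `is_filled`, and `count_filled` are hypothetical names introduced for illustration and are not part of CUTLASS.

```cpp
#include <cassert>

// Hypothetical stand-in for cutlass::FillMode (assumption: only these values matter here).
enum class FillModeSketch { kInvalid, kLower, kUpper };

// Returns true when element (row, col) of a rank-2 tensor would be written:
// the lower triangle including the diagonal for kLower, the upper triangle
// including the diagonal for kUpper, and nothing otherwise.
inline bool is_filled(FillModeSketch mode, int row, int col) {
  if (mode == FillModeSketch::kLower) return row >= col;
  if (mode == FillModeSketch::kUpper) return row <= col;
  return false;
}

// Counts how many elements of an n-by-n matrix the rule selects.
inline int count_filled(FillModeSketch mode, int n) {
  int count = 0;
  for (int row = 0; row < n; ++row) {
    for (int col = 0; col < n; ++col) {
      if (is_filled(mode, row, col)) ++count;
    }
  }
  return count;
}
```

For an n-by-n matrix either triangle covers n(n+1)/2 elements; the remaining elements keep whatever value the tensor held before the fill.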
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Fills a tensor with random values with a Gaussian distribution.
|
||||
template <
|
||||
typename Element ///< Element type
|
||||
@ -566,6 +649,104 @@ struct TensorFillRandomUniformFunc {
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes a random Uniform distribution
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
struct TensorFillSymmetricRandomUniformFunc {
|
||||
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
TensorView view;
|
||||
RandomUniformFunc<Element> func;
|
||||
cutlass::FillMode fill_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Construction of uniform RNG functor.
|
||||
TensorFillSymmetricRandomUniformFunc(
|
||||
TensorView view_ = TensorView(),
|
||||
RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>(),
|
||||
cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid
|
||||
):
|
||||
view(view_), func(func_), fill_mode(fill_mode_) {
|
||||
|
||||
}
|
||||
|
||||
/// Compute random value and update RNG state
|
||||
void operator()(Coord<Layout::kRank> const &coord) const {
|
||||
// Fill half of matrix based on FillMode
|
||||
if (Layout::kRank == 2 &&
|
||||
fill_mode == cutlass::FillMode::kLower &&
|
||||
coord[0] >= coord[1]) {
|
||||
view.at(coord) = func();
|
||||
} else if (Layout::kRank == 2 &&
|
||||
fill_mode == cutlass::FillMode::kUpper &&
|
||||
coord[0] <= coord[1]) {
|
||||
view.at(coord) = func();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// We expect to release this with CUTLASS 2.4. -akerr
|
||||
|
||||
/// Computes a random Uniform distribution and pads diagonal with zeros
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
struct TensorFillPadDiagonalRandomUniformFunc {
|
||||
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
TensorView view;
|
||||
RandomUniformFunc<Element> func;
|
||||
cutlass::FillMode fill_mode;
|
||||
int alignment;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Construction of uniform RNG functor.
|
||||
TensorFillPadDiagonalRandomUniformFunc(
|
||||
TensorView view_ = TensorView(),
|
||||
RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>(),
|
||||
cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid,
|
||||
int alignment_ = 1
|
||||
):
|
||||
view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) {
|
||||
|
||||
}
|
||||
|
||||
/// Compute random value and update RNG state
|
||||
void operator()(Coord<Layout::kRank> const &coord) const {
|
||||
// Fill half of matrix based on FillMode
|
||||
if (Layout::kRank == 2 &&
|
||||
(fill_mode == cutlass::FillMode::kLower) &&
|
||||
(coord[0] >= coord[1]) ||
|
||||
((coord[1] - coord[0]) >= alignment)) {
|
||||
view.at(coord) = func();
|
||||
} else if (Layout::kRank == 2 &&
|
||||
fill_mode == cutlass::FillMode::kUpper &&
|
||||
(coord[0] <= coord[1]) ||
|
||||
((coord[0] - coord[1]) >= alignment)) {
|
||||
view.at(coord) = func();
|
||||
}
|
||||
}
|
||||
};
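In C++, `&&` binds more tightly than `||`, so each condition in `TensorFillPadDiagonalRandomUniformFunc::operator()` groups as `(rank && mode && triangle) || (distance >= alignment)`. The sketch below spells out that grouping with explicit parentheses for the rank-2 case. It only illustrates how the expression as written evaluates and makes no claim about the intended behavior; `FillModeSketch`, `pad_diagonal_would_fill`, and `count_filled` are hypothetical names.

```cpp
#include <cassert>

// Hypothetical stand-in for cutlass::FillMode.
enum class FillModeSketch { kInvalid, kLower, kUpper };

// Evaluates both branch conditions with the grouping that operator precedence
// gives them. Assumes a rank-2 layout, so the `kRank == 2` term is always true.
inline bool pad_diagonal_would_fill(FillModeSketch mode, int row, int col, int alignment) {
  // First branch: the distance test is OR-ed with the whole && chain.
  bool first = ((mode == FillModeSketch::kLower) && (row >= col)) ||
               ((col - row) >= alignment);
  // Second branch: only matters when the first condition is false.
  bool second = ((mode == FillModeSketch::kUpper) && (row <= col)) ||
                ((row - col) >= alignment);
  return first || second;
}

// Counts the elements of an n-by-n matrix that would receive a random value.
inline int count_filled(FillModeSketch mode, int n, int alignment) {
  int count = 0;
  for (int row = 0; row < n; ++row) {
    for (int col = 0; col < n; ++col) {
      if (pad_diagonal_would_fill(mode, row, col, alignment)) ++count;
    }
  }
  return count;
}
```

Under this grouping the distance terms apply regardless of `fill_mode`, so with the default `alignment = 1` every off-diagonal element is written even when the mode is `kInvalid`.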
|
||||
|
||||
} // namespace detail
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -639,6 +820,69 @@ void TensorFillRandomUniform(
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Fills a tensor with random values with a uniform random distribution.
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorFillSymmetricRandomUniform(
|
||||
TensorView<Element, Layout> dst, ///< destination tensor
|
||||
uint64_t seed, ///< seed for RNG
|
||||
cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
|
||||
double max = 1, ///< upper bound of distribution
|
||||
double min = 0, ///< lower bound for distribution
|
||||
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
|
||||
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
|
||||
|
||||
detail::TensorFillSymmetricRandomUniformFunc<Element, Layout> func(
|
||||
dst,
|
||||
random_func,
|
||||
fill_mode
|
||||
);
|
||||
|
||||
TensorForEach(
|
||||
dst.extent(),
|
||||
func
|
||||
);
|
||||
}
|
||||
|
||||
/// Fills a tensor with random values from a uniform distribution and pads zeros along the diagonal
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorFillPadDiagonalRandomUniform(
|
||||
TensorView<Element, Layout> dst, ///< destination tensor
|
||||
uint64_t seed, ///< seed for RNG
|
||||
cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
|
||||
double max = 1, ///< upper bound of distribution
|
||||
double min = 0, ///< lower bound for distribution
|
||||
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
int alignment = 1
|
||||
) {
|
||||
|
||||
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
|
||||
|
||||
detail::TensorFillPadDiagonalRandomUniformFunc<Element, Layout> func(
|
||||
dst,
|
||||
random_func,
|
||||
fill_mode,
|
||||
alignment
|
||||
);
|
||||
|
||||
TensorForEach(
|
||||
dst.extent(),
|
||||
func
|
||||
);
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Fills a tensor with random values with a uniform random distribution.
|
||||
template <
|
||||
typename Element ///< Element type
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
215
tools/util/include/cutlass/util/reference/host/trmm.h
Normal file
@ -0,0 +1,215 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for TRMM in host-side code.
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
DiagType DiagTypeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
||||
>
|
||||
void compute_trmm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
static_assert(SideModeA != SideMode::kInvalid
|
||||
, "Side Mode can either be Left or Right.");
|
||||
|
||||
static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper
|
||||
, "Fill Mode can either be Lower or Upper.");
|
||||
|
||||
using CompareOp = typename TrMatrixCompareOp<FillModeA, DiagTypeA>::Type;
|
||||
|
||||
// Note: batch is ignored.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
// Assuming correct k-dimension value is passed
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Blocking necessary to speed up reference implementation
|
||||
int const Mblock = 16;
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
CompareOp compare_op;
|
||||
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComputeType accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N) {
|
||||
ElementA a = ElementA();
|
||||
ElementB b = ElementB();
|
||||
|
||||
if (SideModeA == SideMode::kLeft) {
|
||||
a = (compare_op(row, k_block)) ?
|
||||
(tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0);
|
||||
if (row == k_block && DiagTypeA == DiagType::kUnit) {
|
||||
a = ElementA(1);
|
||||
}
|
||||
b = tensor_b.at(MatrixCoord(k_block, col));
|
||||
} else if (SideModeA == SideMode::kRight) {
|
||||
a = tensor_b.at(MatrixCoord(row, k_block));
|
||||
b = (compare_op(k_block, col)) ?
|
||||
tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0);
|
||||
if (k_block == col && DiagTypeA == DiagType::kUnit) {
|
||||
b = ElementA(1);
|
||||
}
|
||||
}
|
||||
|
||||
ComputeType compute_a(cast_if_scalar<ComputeType>(a));
|
||||
ComputeType compute_b(cast_if_scalar<ComputeType>(b));
|
||||
|
||||
accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N) {
|
||||
tensor_d.at(coord) = convert_op(
|
||||
alpha * ScalarType(accum[i][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
DiagType DiagTypeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = cutlass::arch::OpMultiplyAdd
|
||||
>
|
||||
struct Trmm;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for multiply-add
|
||||
template <typename ElementA, typename LayoutA, SideMode SideModeA,
|
||||
FillMode FillModeA, DiagType DiagTypeA,
|
||||
typename ElementB, typename LayoutB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename ScalarType, typename ComputeType>
|
||||
struct Trmm<ElementA, LayoutA, SideModeA, FillModeA, DiagTypeA, ElementB, LayoutB,
|
||||
ElementC, LayoutC, ScalarType,
|
||||
ComputeType, arch::OpMultiplyAdd> {
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_trmm<ElementA, LayoutA, SideModeA, FillModeA, DiagTypeA, ElementB, LayoutB,
|
||||
ElementC, LayoutC, ScalarType, ComputeType, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
262
tools/util/include/cutlass/util/reference/host/trmm_complex.h
Normal file
@ -0,0 +1,262 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for complex-valued TRMM in host-side code.
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/blas3.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
ComplexTransform TransformA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
DiagType DiagTypeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
ComplexTransform TransformB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = multiply_add<ComputeType>,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
||||
>
|
||||
void compute_trmm_complex(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
static_assert(SideModeA != SideMode::kInvalid
|
||||
, "Side Mode can either be Left or Right.");
|
||||
|
||||
static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper
|
||||
, "Fill Mode can either be Lower or Upper.");
|
||||
|
||||
using CompareOp = typename TrMatrixCompareOp<FillModeA, DiagTypeA>::Type;
|
||||
|
||||
// Note: batch is ignored.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
// Assuming correct k-dimension value is passed
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Blocking necessary to speedup reference implementation
|
||||
int const Mblock = 16;
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
CompareOp compare_op;
|
||||
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComputeType accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N) {
|
||||
ElementA a = ElementA();
|
||||
ElementB b = ElementB();
|
||||
|
||||
if (SideModeA == SideMode::kLeft) {
|
||||
a = (compare_op(row, k_block)) ?
|
||||
(tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0);
|
||||
if (row == k_block && DiagTypeA == DiagType::kUnit) {
|
||||
a = ElementA(1);
|
||||
}
|
||||
b = tensor_b.at(MatrixCoord(k_block, col));
|
||||
} else if (SideModeA == SideMode::kRight) {
|
||||
a = tensor_b.at(MatrixCoord(row, k_block));
|
||||
b = (compare_op(k_block, col)) ?
|
||||
tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0);
|
||||
if (k_block == col && DiagTypeA == DiagType::kUnit) {
|
||||
b = ElementA(1);
|
||||
}
|
||||
}
|
||||
|
||||
ComputeType a_ik = ComputeType(a);
|
||||
ComputeType b_kj = ComputeType(b);
|
||||
|
||||
// Conjugate, and hence hermitian, is only allowed for the triangular matrix
|
||||
if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) {
|
||||
a_ik = conj(a_ik);
|
||||
} else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) {
|
||||
b_kj = conj(b_kj);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N) {
|
||||
tensor_d.at(coord) = convert_op(
|
||||
alpha * ScalarType(accum[i][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
ComplexTransform TransformA,
|
||||
SideMode SideModeA,
|
||||
FillMode FillModeA,
|
||||
DiagType DiagTypeA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
ComplexTransform TransformB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex
|
||||
>
|
||||
struct TrmmComplex;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for multiply-add
|
||||
template <typename ElementA, typename LayoutA, ComplexTransform TransformA,
|
||||
SideMode SideModeA, FillMode FillModeA, DiagType DiagTypeA,
|
||||
typename ElementB, typename LayoutB, ComplexTransform TransformB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename ScalarType, typename ComputeType>
|
||||
struct TrmmComplex<ElementA, LayoutA, TransformA,
|
||||
SideModeA, FillModeA, DiagTypeA,
|
||||
ElementB, LayoutB, TransformB,
|
||||
ElementC, LayoutC, ScalarType,
|
||||
ComputeType, arch::OpMultiplyAddComplex> {
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_trmm_complex<ElementA, LayoutA, TransformA,
|
||||
SideModeA, FillModeA, DiagTypeA,
|
||||
ElementB, LayoutB, TransformB,
|
||||
ElementC, LayoutC,
|
||||
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for gaussian multiply-add
|
||||
template <typename ElementA, typename LayoutA, ComplexTransform TransformA,
|
||||
SideMode SideModeA, FillMode FillModeA, DiagType DiagTypeA,
|
||||
typename ElementB, typename LayoutB, ComplexTransform TransformB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename ScalarType, typename ComputeType>
|
||||
struct TrmmComplex<ElementA, LayoutA, TransformA,
|
||||
SideModeA, FillModeA, DiagTypeA,
|
||||
ElementB, LayoutB, TransformB,
|
||||
ElementC, LayoutC, ScalarType,
|
||||
ComputeType, arch::OpMultiplyAddGaussianComplex> {
|
||||
|
||||
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
||||
TensorRef<ElementA, LayoutA> tensor_a,
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum = ComputeType(0)) {
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_trmm_complex<ElementA, LayoutA, TransformA,
|
||||
SideModeA, FillModeA, DiagTypeA,
|
||||
ElementB, LayoutB, TransformB,
|
||||
ElementC, LayoutC,
|
||||
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
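The masking logic in the host reference above (treat entries outside the selected triangle as zero, force a unit diagonal to one, optionally conjugate the triangular operand) can be sketched without any CUTLASS types. This is a minimal illustration only, assuming row-major `std::vector` storage and the left-side case; `trmm_left`, `Fill`, `Diag`, and `cd` are hypothetical names introduced here and are not part of the library.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for this sketch; these are not the CUTLASS enums.
enum class Fill { Lower, Upper };
enum class Diag { NonUnit, Unit };

using cd = std::complex<double>;

// Computes D = alpha * op(A) * B, where A is an M x M triangular matrix
// applied from the left and B, D are M x N. All matrices are row-major.
// Entries of A outside the selected triangle are treated as zero, and a
// unit diagonal is treated as one, so the values stored there are ignored.
std::vector<cd> trmm_left(int M, int N, cd alpha,
                          const std::vector<cd>& A,
                          const std::vector<cd>& B,
                          Fill fill, Diag diag, bool conjugate_a) {
  assert(A.size() == static_cast<std::size_t>(M) * M);
  assert(B.size() == static_cast<std::size_t>(M) * N);

  std::vector<cd> D(static_cast<std::size_t>(M) * N);

  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      cd accum(0.0, 0.0);
      for (int k = 0; k < M; ++k) {
        // Keep only the entries inside the requested triangle.
        bool in_triangle = (fill == Fill::Lower) ? (row >= k) : (row <= k);
        cd a = in_triangle ? A[row * M + k] : cd(0.0, 0.0);

        // A unit-diagonal matrix has an implicit 1 on the diagonal.
        if (row == k && diag == Diag::Unit) {
          a = cd(1.0, 0.0);
        }

        // Conjugation applies to the triangular operand only.
        if (conjugate_a) {
          a = std::conj(a);
        }

        accum += a * B[k * N + col];
      }
      D[row * N + col] = alpha * accum;
    }
  }
  return D;
}
```

The right-side case mirrors this by swapping the roles of the operands, so the triangular test is applied to the `(k, col)` entry instead of `(row, k)`.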
|
||||
@ -1,25 +1,31 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||