Files
cutlass/include/cutlass/matrix.h
ANIKET SHIVAM 66d9cddc83 New updates for 2.11 (#775)
* New updates.

* Minor profiler updates

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
2023-01-20 16:32:57 -05:00

14130 lines
356 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
\file
\brief Matrix classes with value semantics.
*/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <iosfwd>
#include <cmath>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/matrix.h"
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Primary template with partial specializations to follow
template <typename Element, int Rows, int Columns> struct Matrix;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 1-by-2 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 1, 2> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 1;
/// Number of columns in matrix
static int const kColumns = 2;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 2;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 1-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 1-by-2 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1
) {
data[0] = _0_0; data[1] = _0_1;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> transpose() const {
Matrix<Element, 2, 1> mt;
mt.data[0] = data[0];
mt.data[1] = data[1];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 1 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 1 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> row(int i) const {
return slice_1x2(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 2> const &v, int i = 0) {
return set_slice_1x2(v, i, 0);
}
/// Forms a 1-by-2 matrix by horizontally concatenating an Element with an Element
CUTLASS_HOST_DEVICE
static Matrix hcat(Element lhs, Element rhs) {
return Matrix(
lhs, rhs);
}
/// Concatenates this matrix with a an Element to form a 1-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> hcat(Element rhs) const {
return Matrix<Element, 1, 3>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 1-by-2 matrix to form a 1-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> hcat(Matrix<Element, 1, 2> const & rhs) const {
return Matrix<Element, 1, 4>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 1-by-2 matrix to form a 2-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> vcat(Matrix<Element, 1, 2> const & rhs) const {
return Matrix<Element, 2, 2>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-2 matrix to form a 3-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> vcat(Matrix<Element, 2, 2> const & rhs) const {
return Matrix<Element, 3, 2>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 3-by-2 matrix to form a 4-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> vcat(Matrix<Element, 3, 2> const & rhs) const {
return Matrix<Element, 4, 2>::vcat(*this, rhs);
}
/// Elementwise add operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
return result;
}
/// Elementwise add operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
return *this;
}
/// Elementwise subtract operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
return result;
}
/// Elementwise subtract operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
return *this;
}
/// Elementwise multiply operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
return result;
}
/// Scalar multiply operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
return result;
}
/// Scalar multiply operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
return *this;
}
/// Elementwise divide operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
return result;
}
/// Scalar divide operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
return result;
}
/// Scalar divide operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
return *this;
}
/// Elementwise divide operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (1-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
return m;
}
/// Matrix product of size 1-by-1-by-2
CUTLASS_HOST_DEVICE
Element product(Matrix<Element, 2, 1> const &rhs, Element accum = Element()) const {
// k=0
accum += data[0] * rhs.data[0];
// k=1
accum += data[1] * rhs.data[1];
return accum;
}
/// Matrix product of size 1-by-1-by-2
CUTLASS_HOST_DEVICE
Element operator*(Matrix<Element, 2, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> product(
Matrix<Element, 2, 2> const &rhs,
Matrix<Element, 1, 2> accum = Matrix<Element, 1, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
return accum;
}
/// Matrix product of size 1-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> operator*(Matrix<Element, 2, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 2, 2> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 1-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> product(
Matrix<Element, 2, 3> const &rhs,
Matrix<Element, 1, 3> accum = Matrix<Element, 1, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
return accum;
}
/// Matrix product of size 1-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> operator*(Matrix<Element, 2, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> product(
Matrix<Element, 2, 4> const &rhs,
Matrix<Element, 1, 4> accum = Matrix<Element, 1, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
return accum;
}
/// Matrix product of size 1-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> operator*(Matrix<Element, 2, 4> const &rhs) const {
return product(rhs);
}
/// Dot product of vectors with extent 2
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 2, 1> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
return accum;
}
/// Dot product of vectors with extent 2
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 1, 2> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
return accum;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
return accum;
}
};
/// Template alias for 1-by-2 matrix
template <typename Element>
using Matrix1x2 = Matrix<Element, 1, 2>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix1x2<Element> make_Matrix1x2(
Element _0_0, Element _0_1
) {
return Matrix1x2<Element>(
_0_0, _0_1
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 1-by-3 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 1, 3> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 1;
/// Number of columns in matrix
static int const kColumns = 3;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 3;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 1-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 1-by-3 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> transpose() const {
Matrix<Element, 3, 1> mt;
mt.data[0] = data[0];
mt.data[1] = data[1];
mt.data[2] = data[2];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 1 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 1 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> row(int i) const {
return slice_1x3(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 3> const &v, int i = 0) {
return set_slice_1x3(v, i, 0);
}
/// Forms a 1-by-3 matrix by horizontally concatenating an Element with a 1-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Element lhs, Matrix<Element, 1, 2> const & rhs) {
return Matrix(
lhs, rhs.at(0, 0), rhs.at(0, 1));
}
/// Forms a 1-by-3 matrix by horizontally concatenating a 1-by-2 matrix with an Element
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 1, 2> const & lhs, Element rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs);
}
/// Concatenates this matrix with a an Element to form a 1-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> hcat(Element rhs) const {
return Matrix<Element, 1, 4>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 1-by-3 matrix to form a 2-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> vcat(Matrix<Element, 1, 3> const & rhs) const {
return Matrix<Element, 2, 3>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-3 matrix to form a 3-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> vcat(Matrix<Element, 2, 3> const & rhs) const {
return Matrix<Element, 3, 3>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 3-by-3 matrix to form a 4-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> vcat(Matrix<Element, 3, 3> const & rhs) const {
return Matrix<Element, 4, 3>::vcat(*this, rhs);
}
/// Elementwise add operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
return result;
}
/// Elementwise add operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
return *this;
}
/// Elementwise subtract operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
return result;
}
/// Elementwise subtract operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
return *this;
}
/// Elementwise multiply operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
return result;
}
/// Scalar multiply operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
return result;
}
/// Scalar multiply operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
return *this;
}
/// Elementwise divide operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
return result;
}
/// Scalar divide operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
return result;
}
/// Scalar divide operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
return *this;
}
/// Elementwise divide operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (1-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
return m;
}
/// Matrix product of size 1-by-1-by-3
CUTLASS_HOST_DEVICE
Element product(Matrix<Element, 3, 1> const &rhs, Element accum = Element()) const {
// k=0
accum += data[0] * rhs.data[0];
// k=1
accum += data[1] * rhs.data[1];
// k=2
accum += data[2] * rhs.data[2];
return accum;
}
/// Matrix product of size 1-by-1-by-3
CUTLASS_HOST_DEVICE
Element operator*(Matrix<Element, 3, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> product(
Matrix<Element, 3, 2> const &rhs,
Matrix<Element, 1, 2> accum = Matrix<Element, 1, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
return accum;
}
/// Matrix product of size 1-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> operator*(Matrix<Element, 3, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> product(
Matrix<Element, 3, 3> const &rhs,
Matrix<Element, 1, 3> accum = Matrix<Element, 1, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
return accum;
}
/// Matrix product of size 1-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> operator*(Matrix<Element, 3, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 3, 3> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 1-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> product(
Matrix<Element, 3, 4> const &rhs,
Matrix<Element, 1, 4> accum = Matrix<Element, 1, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
return accum;
}
/// Matrix product of size 1-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> operator*(Matrix<Element, 3, 4> const &rhs) const {
return product(rhs);
}
/// Dot product of vectors with extent 3
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 3, 1> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
return accum;
}
/// Dot product of vectors with extent 3
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 1, 3> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
return accum;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
return accum;
}
/// Cross product
CUTLASS_HOST_DEVICE
Matrix cross(Matrix const &rhs) const {
return Matrix(
data[1] * rhs.data[2] - data[2] * rhs.data[1],
data[0] * rhs.data[2] - data[2] * rhs.data[1],
data[0] * rhs.data[1] - data[1] * rhs.data[0]
);
}
};
/// Template alias for 1-by-3 matrix
template <typename Element>
using Matrix1x3 = Matrix<Element, 1, 3>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix1x3<Element> make_Matrix1x3(
Element _0_0, Element _0_1, Element _0_2
) {
return Matrix1x3<Element>(
_0_0, _0_1, _0_2
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 1-by-4 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 1, 4> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 1;
/// Number of columns in matrix
static int const kColumns = 4;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 4;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 1-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 1-by-4 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2, Element _0_3
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> transpose() const {
Matrix<Element, 4, 1> mt;
mt.data[0] = data[0];
mt.data[1] = data[1];
mt.data[2] = data[2];
mt.data[3] = data[3];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 1 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 1 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> slice_1x4(int i = 0, int j = 0) const {
Matrix<Element, 1, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x4(Matrix<Element, 1, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> row(int i) const {
return slice_1x4(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 4> const &v, int i = 0) {
return set_slice_1x4(v, i, 0);
}
/// Forms a 1-by-4 matrix by horizontally concatenating an Element with a 1-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Element lhs, Matrix<Element, 1, 3> const & rhs) {
return Matrix(
lhs, rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2));
}
/// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-2 matrix with a 1-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 1, 2> const & lhs, Matrix<Element, 1, 2> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1));
}
/// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-3 matrix with an Element
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 1, 3> const & lhs, Element rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs);
}
/// Concatenates this matrix with a a 1-by-4 matrix to form a 2-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> vcat(Matrix<Element, 1, 4> const & rhs) const {
return Matrix<Element, 2, 4>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-4 matrix to form a 3-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> vcat(Matrix<Element, 2, 4> const & rhs) const {
return Matrix<Element, 3, 4>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 3-by-4 matrix to form a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> vcat(Matrix<Element, 3, 4> const & rhs) const {
return Matrix<Element, 4, 4>::vcat(*this, rhs);
}
/// Elementwise add operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
return result;
}
/// Elementwise add operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
return *this;
}
/// Elementwise subtract operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
return result;
}
/// Elementwise subtract operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
return *this;
}
/// Elementwise multiply operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
return result;
}
/// Scalar multiply operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
return result;
}
/// Scalar multiply operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
return *this;
}
/// Elementwise divide operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
return result;
}
/// Scalar divide operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
return result;
}
/// Scalar divide operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
return *this;
}
/// Elementwise divide operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (1-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
return m;
}
/// Matrix product of size 1-by-1-by-4
CUTLASS_HOST_DEVICE
Element product(Matrix<Element, 4, 1> const &rhs, Element accum = Element()) const {
// k=0
accum += data[0] * rhs.data[0];
// k=1
accum += data[1] * rhs.data[1];
// k=2
accum += data[2] * rhs.data[2];
// k=3
accum += data[3] * rhs.data[3];
return accum;
}
/// Matrix product of size 1-by-1-by-4
CUTLASS_HOST_DEVICE
Element operator*(Matrix<Element, 4, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> product(
Matrix<Element, 4, 2> const &rhs,
Matrix<Element, 1, 2> accum = Matrix<Element, 1, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
// k=3
accum.data[0] += data[3] * rhs.data[6];
accum.data[1] += data[3] * rhs.data[7];
return accum;
}
/// Matrix product of size 1-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> operator*(Matrix<Element, 4, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> product(
Matrix<Element, 4, 3> const &rhs,
Matrix<Element, 1, 3> accum = Matrix<Element, 1, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
// k=3
accum.data[0] += data[3] * rhs.data[9];
accum.data[1] += data[3] * rhs.data[10];
accum.data[2] += data[3] * rhs.data[11];
return accum;
}
/// Matrix product of size 1-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> operator*(Matrix<Element, 4, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> product(
Matrix<Element, 4, 4> const &rhs,
Matrix<Element, 1, 4> accum = Matrix<Element, 1, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
// k=3
accum.data[0] += data[3] * rhs.data[12];
accum.data[1] += data[3] * rhs.data[13];
accum.data[2] += data[3] * rhs.data[14];
accum.data[3] += data[3] * rhs.data[15];
return accum;
}
/// Matrix product of size 1-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> operator*(Matrix<Element, 4, 4> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 1-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 4, 4> const &rhs) {
*this = product(rhs);
return *this;
}
/// Dot product of vectors with extent 4
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 4, 1> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
accum += data[3] * rhs.data[3];
return accum;
}
/// Dot product of vectors with extent 4
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 1, 4> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
accum += data[3] * rhs.data[3];
return accum;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
return accum;
}
};
/// Template alias for 1-by-4 matrix
template <typename Element>
using Matrix1x4 = Matrix<Element, 1, 4>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix1x4<Element> make_Matrix1x4(
Element _0_0, Element _0_1, Element _0_2, Element _0_3
) {
return Matrix1x4<Element>(
_0_0, _0_1, _0_2, _0_3
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 2-by-1 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 2, 1> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 2;
/// Number of columns in matrix
static int const kColumns = 1;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 2;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 2-by-1 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 2-by-1 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0,
Element _1_0
) {
data[0] = _0_0;
data[1] = _1_0;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> transpose() const {
Matrix<Element, 1, 2> mt;
mt.data[0] = data[0];
mt.data[1] = data[1];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 1 + j + 0];
m.data[1] = data[i * 1 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 1 + j + 0] = m.data[0];
data[i * 1 + j + 1] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> column(int j) const {
return slice_2x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 2, 1> const &v, int j =0) {
return set_slice_2x1(v, 0, j);
}
/// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> hcat(Matrix<Element, 2, 1> const & rhs) const {
return Matrix<Element, 2, 2>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> hcat(Matrix<Element, 2, 2> const & rhs) const {
return Matrix<Element, 2, 3>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-3 matrix to form a 2-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> hcat(Matrix<Element, 2, 3> const & rhs) const {
return Matrix<Element, 2, 4>::hcat(*this, rhs);
}
/// Forms a 2-by-1 matrix by vertically concatenating an Element with an Element
CUTLASS_HOST_DEVICE
static Matrix vcat(Element upper, Element lower) {
return Matrix(
upper
, lower);
}
/// Concatenates this matrix with a an Element to form a 3-by-1 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> vcat(Element rhs) const {
return Matrix<Element, 3, 1>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-1 matrix to form a 4-by-1 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> vcat(Matrix<Element, 2, 1> const & rhs) const {
return Matrix<Element, 4, 1>::vcat(*this, rhs);
}
/// Elementwise add operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
return result;
}
/// Elementwise add operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
return *this;
}
/// Elementwise subtract operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
return result;
}
/// Elementwise subtract operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
return *this;
}
/// Elementwise multiply operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
return result;
}
/// Scalar multiply operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
return result;
}
/// Scalar multiply operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
return *this;
}
/// Elementwise divide operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
return result;
}
/// Scalar divide operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
return result;
}
/// Scalar divide operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
return *this;
}
/// Elementwise divide operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (2-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
return m;
}
/// Matrix product of size 2-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> product(
Matrix<Element, 1, 1> const &rhs,
Matrix<Element, 2, 1> accum = Matrix<Element, 2, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[1] * rhs.data[0];
return accum;
}
/// Matrix product of size 2-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> operator*(Matrix<Element, 1, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 1, 1> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 2-by-2-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> product(
Matrix<Element, 1, 2> const &rhs,
Matrix<Element, 2, 2> accum = Matrix<Element, 2, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[1] * rhs.data[0];
accum.data[3] += data[1] * rhs.data[1];
return accum;
}
/// Matrix product of size 2-by-2-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> operator*(Matrix<Element, 1, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-3-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> product(
Matrix<Element, 1, 3> const &rhs,
Matrix<Element, 2, 3> accum = Matrix<Element, 2, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[1] * rhs.data[0];
accum.data[4] += data[1] * rhs.data[1];
accum.data[5] += data[1] * rhs.data[2];
return accum;
}
/// Matrix product of size 2-by-3-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> operator*(Matrix<Element, 1, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-4-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> product(
Matrix<Element, 1, 4> const &rhs,
Matrix<Element, 2, 4> accum = Matrix<Element, 2, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[1] * rhs.data[0];
accum.data[5] += data[1] * rhs.data[1];
accum.data[6] += data[1] * rhs.data[2];
accum.data[7] += data[1] * rhs.data[3];
return accum;
}
/// Matrix product of size 2-by-4-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> operator*(Matrix<Element, 1, 4> const &rhs) const {
return product(rhs);
}
/// Dot product of vectors with extent 2
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 2, 1> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
return accum;
}
/// Dot product of vectors with extent 2
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 1, 2> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
return accum;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
return accum;
}
};
/// Template alias for 2-by-1 matrix
template <typename Element>
using Matrix2x1 = Matrix<Element, 2, 1>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix2x1<Element> make_Matrix2x1(
Element _0_0,
Element _1_0
) {
return Matrix2x1<Element>(
_0_0,
_1_0
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 2-by-2 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 2, 2> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 2;
/// Number of columns in matrix
static int const kColumns = 2;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 4;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 2-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 2-by-2 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1,
Element _1_0, Element _1_1
) {
data[0] = _0_0; data[1] = _0_1;
data[2] = _1_0; data[3] = _1_1;
}
/// Constucts a 2-by-2 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 2> const &row_0,
Matrix<Element, 1, 2> const &row_1
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_1.data[0];
data[3] = row_1.data[1];
}
/// Static method to construct a 2-by-2 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 2, 1> const &column_0,
Matrix<Element, 2, 1> const &column_1
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_0.data[1];
result.data[3] = column_1.data[1];
return result;
}
/// Constructs an identity matrix
CUTLASS_HOST_DEVICE
static Matrix identity() {
Matrix m;
m.data[0] = Element(1);
m.data[3] = Element(1);
return m;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 2, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[3] = diag.data[1];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 2> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[3] = diag.data[1];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> diagonal() const {
Matrix<Element, 2, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[3];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> transpose() const {
Matrix<Element, 2, 2> mt;
mt.data[0] = data[0];
mt.data[2] = data[1];
mt.data[1] = data[2];
mt.data[3] = data[3];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> row(int i) const {
return slice_1x2(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 2> const &v, int i = 0) {
return set_slice_1x2(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 2] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> column(int j) const {
return slice_2x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 2, 1> const &v, int j =0) {
return set_slice_2x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
m.data[2] = data[i * 2 + j + 2];
m.data[3] = data[i * 2 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
data[i * 2 + j + 2] = m.data[2];
data[i * 2 + j + 3] = m.data[3];
return *this;
}
/// Forms a 2-by-2 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 2, 1> const & lhs, Matrix<Element, 2, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0)
, lhs.at(1, 0), rhs.at(1, 0));
}
/// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> hcat(Matrix<Element, 2, 1> const & rhs) const {
return Matrix<Element, 2, 3>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> hcat(Matrix<Element, 2, 2> const & rhs) const {
return Matrix<Element, 2, 4>::hcat(*this, rhs);
}
/// Forms a 2-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 1-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 2> const & upper, Matrix<Element, 1, 2> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1)
, lower.at(0, 0), lower.at(0, 1));
}
/// Concatenates this matrix with a a 1-by-2 matrix to form a 3-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> vcat(Matrix<Element, 1, 2> const & rhs) const {
return Matrix<Element, 3, 2>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-2 matrix to form a 4-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> vcat(Matrix<Element, 2, 2> const & rhs) const {
return Matrix<Element, 4, 2>::vcat(*this, rhs);
}
/// Forms a 2-by-2 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Element B,
Element C, Element D) {
return Matrix(
A, B
, C, D
);
}
/// Elementwise add operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
return result;
}
/// Elementwise add operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
return *this;
}
/// Elementwise subtract operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
return result;
}
/// Elementwise subtract operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
return *this;
}
/// Elementwise multiply operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
return result;
}
/// Scalar multiply operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
return result;
}
/// Scalar multiply operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
return *this;
}
/// Elementwise divide operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
return result;
}
/// Scalar divide operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
return result;
}
/// Scalar divide operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
return *this;
}
/// Elementwise divide operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (2-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
return m;
}
/// Matrix product of size 2-by-1-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> product(
Matrix<Element, 2, 1> const &rhs,
Matrix<Element, 2, 1> accum = Matrix<Element, 2, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[2] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[3] * rhs.data[1];
return accum;
}
/// Matrix product of size 2-by-1-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> operator*(Matrix<Element, 2, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> product(
Matrix<Element, 2, 2> const &rhs,
Matrix<Element, 2, 2> accum = Matrix<Element, 2, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[2] * rhs.data[0];
accum.data[3] += data[2] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[3] * rhs.data[2];
accum.data[3] += data[3] * rhs.data[3];
return accum;
}
/// Matrix product of size 2-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> operator*(Matrix<Element, 2, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 2, 2> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 2-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> product(
Matrix<Element, 2, 3> const &rhs,
Matrix<Element, 2, 3> accum = Matrix<Element, 2, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[2] * rhs.data[0];
accum.data[4] += data[2] * rhs.data[1];
accum.data[5] += data[2] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[3] * rhs.data[3];
accum.data[4] += data[3] * rhs.data[4];
accum.data[5] += data[3] * rhs.data[5];
return accum;
}
/// Matrix product of size 2-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> operator*(Matrix<Element, 2, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> product(
Matrix<Element, 2, 4> const &rhs,
Matrix<Element, 2, 4> accum = Matrix<Element, 2, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[2] * rhs.data[0];
accum.data[5] += data[2] * rhs.data[1];
accum.data[6] += data[2] * rhs.data[2];
accum.data[7] += data[2] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[3] * rhs.data[4];
accum.data[5] += data[3] * rhs.data[5];
accum.data[6] += data[3] * rhs.data[6];
accum.data[7] += data[3] * rhs.data[7];
return accum;
}
/// Matrix product of size 2-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> operator*(Matrix<Element, 2, 4> const &rhs) const {
return product(rhs);
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[3];
return accum;
}
/// Returns 2-by-2 rotation matrix
CUTLASS_HOST_DEVICE
static Matrix rotation(Element theta) {
Element c = fast_cos(theta);
Element s = fast_sin(theta);
return Matrix(
c, -s,
s, c
);
}
/// Computes the determinant of a 2-by-2 matrix
CUTLASS_HOST_DEVICE
Element determinant(Element accum = Element()) const {
accum += data[0] * data[3] - data[1] * data[2];
return accum;
}
/// Computes the inverse of a 2-by-2 matrix given
/// the matrix's determinant
CUTLASS_HOST_DEVICE
Matrix inverse(Element det) const {
return Matrix(
data[3], -data[1],
-data[2], data[0]
) * (Element(1) / det);
}
/// Computes the inverse of a 2-by-2 matrix.
CUTLASS_HOST_DEVICE
Matrix inverse() const {
return inverse(determinant());
}
};
/// Template alias for 2-by-2 matrix
template <typename Element>
using Matrix2x2 = Matrix<Element, 2, 2>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix2x2<Element> make_Matrix2x2(
Element _0_0, Element _0_1,
Element _1_0, Element _1_1
) {
return Matrix2x2<Element>(
_0_0, _0_1,
_1_0, _1_1
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 2-by-3 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 2, 3> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 2;
/// Number of columns in matrix
static int const kColumns = 3;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 6;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 2-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 2-by-3 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2,
Element _1_0, Element _1_1, Element _1_2
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2;
data[3] = _1_0; data[4] = _1_1; data[5] = _1_2;
}
/// Constucts a 2-by-3 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 3> const &row_0,
Matrix<Element, 1, 3> const &row_1
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_0.data[2];
data[3] = row_1.data[0];
data[4] = row_1.data[1];
data[5] = row_1.data[2];
}
/// Static method to construct a 2-by-3 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 3, 1> const &column_0,
Matrix<Element, 3, 1> const &column_1,
Matrix<Element, 3, 1> const &column_2
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_2.data[0];
result.data[3] = column_0.data[1];
result.data[4] = column_1.data[1];
result.data[5] = column_2.data[1];
return result;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 2, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[3] = diag.data[1];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 2> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[3] = diag.data[1];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> diagonal() const {
Matrix<Element, 2, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[3];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> transpose() const {
Matrix<Element, 3, 2> mt;
mt.data[0] = data[0];
mt.data[2] = data[1];
mt.data[4] = data[2];
mt.data[1] = data[3];
mt.data[3] = data[4];
mt.data[5] = data[5];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> row(int i) const {
return slice_1x3(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 3> const &v, int i = 0) {
return set_slice_1x3(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 3] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> column(int j) const {
return slice_2x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 2, 1> const &v, int j =0) {
return set_slice_2x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 3];
m.data[3] = data[i * 3 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 3] = m.data[2];
data[i * 3 + j + 4] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> slice_2x3(int i = 0, int j = 0) const {
Matrix<Element, 2, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
m.data[3] = data[i * 3 + j + 3];
m.data[4] = data[i * 3 + j + 4];
m.data[5] = data[i * 3 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x3(Matrix<Element, 2, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
data[i * 3 + j + 3] = m.data[3];
data[i * 3 + j + 4] = m.data[4];
data[i * 3 + j + 5] = m.data[5];
return *this;
}
/// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 2, 1> const & lhs, Matrix<Element, 2, 2> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1)
, lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1));
}
/// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 2, 2> const & lhs, Matrix<Element, 2, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0)
, lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0));
}
/// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> hcat(Matrix<Element, 2, 1> const & rhs) const {
return Matrix<Element, 2, 4>::hcat(*this, rhs);
}
/// Forms a 2-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 1-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 3> const & upper, Matrix<Element, 1, 3> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2));
}
/// Concatenates this matrix with a a 1-by-3 matrix to form a 3-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> vcat(Matrix<Element, 1, 3> const & rhs) const {
return Matrix<Element, 3, 3>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-3 matrix to form a 4-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> vcat(Matrix<Element, 2, 3> const & rhs) const {
return Matrix<Element, 4, 3>::vcat(*this, rhs);
}
/// Forms a 2-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Matrix<Element, 1, 2> const & B,
Element C, Matrix<Element, 1, 2> const & D) {
return Matrix(
A, B.at(0, 0), B.at(0, 1)
, C, D.at(0, 0), D.at(0, 1)
);
}
/// Forms a 2-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 2> const & A, Element B,
Matrix<Element, 1, 2> const & C, Element D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B
, C.at(0, 0), C.at(0, 1), D
);
}
/// Elementwise add operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
return result;
}
/// Elementwise add operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
return *this;
}
/// Elementwise subtract operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
return result;
}
/// Elementwise subtract operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
return *this;
}
/// Elementwise multiply operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
return result;
}
/// Scalar multiply operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
return result;
}
/// Scalar multiply operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
return *this;
}
/// Elementwise divide operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
return result;
}
/// Scalar divide operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
return result;
}
/// Scalar divide operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
return *this;
}
/// Elementwise divide operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (2-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
return m;
}
/// Matrix product of size 2-by-1-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> product(
Matrix<Element, 3, 1> const &rhs,
Matrix<Element, 2, 1> accum = Matrix<Element, 2, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[3] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[4] * rhs.data[1];
// k=2
accum.data[0] += data[2] * rhs.data[2];
accum.data[1] += data[5] * rhs.data[2];
return accum;
}
/// Matrix product of size 2-by-1-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> operator*(Matrix<Element, 3, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> product(
Matrix<Element, 3, 2> const &rhs,
Matrix<Element, 2, 2> accum = Matrix<Element, 2, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[3] * rhs.data[0];
accum.data[3] += data[3] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[4] * rhs.data[2];
accum.data[3] += data[4] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
accum.data[2] += data[5] * rhs.data[4];
accum.data[3] += data[5] * rhs.data[5];
return accum;
}
/// Matrix product of size 2-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> operator*(Matrix<Element, 3, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> product(
Matrix<Element, 3, 3> const &rhs,
Matrix<Element, 2, 3> accum = Matrix<Element, 2, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[3] * rhs.data[0];
accum.data[4] += data[3] * rhs.data[1];
accum.data[5] += data[3] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[4] * rhs.data[3];
accum.data[4] += data[4] * rhs.data[4];
accum.data[5] += data[4] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
accum.data[3] += data[5] * rhs.data[6];
accum.data[4] += data[5] * rhs.data[7];
accum.data[5] += data[5] * rhs.data[8];
return accum;
}
/// Matrix product of size 2-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> operator*(Matrix<Element, 3, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 3, 3> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 2-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> product(
Matrix<Element, 3, 4> const &rhs,
Matrix<Element, 2, 4> accum = Matrix<Element, 2, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[3] * rhs.data[0];
accum.data[5] += data[3] * rhs.data[1];
accum.data[6] += data[3] * rhs.data[2];
accum.data[7] += data[3] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[4] * rhs.data[4];
accum.data[5] += data[4] * rhs.data[5];
accum.data[6] += data[4] * rhs.data[6];
accum.data[7] += data[4] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
accum.data[4] += data[5] * rhs.data[8];
accum.data[5] += data[5] * rhs.data[9];
accum.data[6] += data[5] * rhs.data[10];
accum.data[7] += data[5] * rhs.data[11];
return accum;
}
/// Matrix product of size 2-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> operator*(Matrix<Element, 3, 4> const &rhs) const {
return product(rhs);
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[4];
return accum;
}
};
/// Template alias for 2-by-3 matrix
template <typename Element>
using Matrix2x3 = Matrix<Element, 2, 3>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix2x3<Element> make_Matrix2x3(
Element _0_0, Element _0_1, Element _0_2,
Element _1_0, Element _1_1, Element _1_2
) {
return Matrix2x3<Element>(
_0_0, _0_1, _0_2,
_1_0, _1_1, _1_2
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 2-by-4 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 2, 4> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 2;
/// Number of columns in matrix
static int const kColumns = 4;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 8;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 2-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 2-by-4 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
Element _1_0, Element _1_1, Element _1_2, Element _1_3
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3;
data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3;
}
/// Constucts a 2-by-4 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 4> const &row_0,
Matrix<Element, 1, 4> const &row_1
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_0.data[2];
data[3] = row_0.data[3];
data[4] = row_1.data[0];
data[5] = row_1.data[1];
data[6] = row_1.data[2];
data[7] = row_1.data[3];
}
/// Static method to construct a 2-by-4 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 4, 1> const &column_0,
Matrix<Element, 4, 1> const &column_1,
Matrix<Element, 4, 1> const &column_2,
Matrix<Element, 4, 1> const &column_3
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_2.data[0];
result.data[3] = column_3.data[0];
result.data[4] = column_0.data[1];
result.data[5] = column_1.data[1];
result.data[6] = column_2.data[1];
result.data[7] = column_3.data[1];
return result;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
m.data[6] = s;
m.data[7] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 2, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[3] = diag.data[1];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 2> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[3] = diag.data[1];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> diagonal() const {
Matrix<Element, 2, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[3];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> transpose() const {
Matrix<Element, 4, 2> mt;
mt.data[0] = data[0];
mt.data[2] = data[1];
mt.data[4] = data[2];
mt.data[6] = data[3];
mt.data[1] = data[4];
mt.data[3] = data[5];
mt.data[5] = data[6];
mt.data[7] = data[7];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 2 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> slice_1x4(int i = 0, int j = 0) const {
Matrix<Element, 1, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x4(Matrix<Element, 1, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> row(int i) const {
return slice_1x4(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 4> const &v, int i = 0) {
return set_slice_1x4(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 4] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> column(int j) const {
return slice_2x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 2, 1> const &v, int j =0) {
return set_slice_2x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 4];
m.data[3] = data[i * 4 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 4] = m.data[2];
data[i * 4 + j + 5] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> slice_2x3(int i = 0, int j = 0) const {
Matrix<Element, 2, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 4];
m.data[4] = data[i * 4 + j + 5];
m.data[5] = data[i * 4 + j + 6];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x3(Matrix<Element, 2, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 4] = m.data[3];
data[i * 4 + j + 5] = m.data[4];
data[i * 4 + j + 6] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> slice_2x4(int i = 0, int j = 0) const {
Matrix<Element, 2, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
m.data[4] = data[i * 4 + j + 4];
m.data[5] = data[i * 4 + j + 5];
m.data[6] = data[i * 4 + j + 6];
m.data[7] = data[i * 4 + j + 7];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x4(Matrix<Element, 2, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
data[i * 4 + j + 4] = m.data[4];
data[i * 4 + j + 5] = m.data[5];
data[i * 4 + j + 6] = m.data[6];
data[i * 4 + j + 7] = m.data[7];
return *this;
}
/// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 2, 1> const & lhs, Matrix<Element, 2, 3> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2)
, lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2));
}
/// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 2, 2> const & lhs, Matrix<Element, 2, 2> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1)
, lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1));
}
/// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-3 matrix with a 2-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 2, 3> const & lhs, Matrix<Element, 2, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0)
, lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0));
}
/// Forms a 2-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 1-by-4 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 4> const & upper, Matrix<Element, 1, 4> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3));
}
/// Concatenates this matrix with a a 1-by-4 matrix to form a 3-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> vcat(Matrix<Element, 1, 4> const & rhs) const {
return Matrix<Element, 3, 4>::vcat(*this, rhs);
}
/// Concatenates this matrix with a a 2-by-4 matrix to form a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> vcat(Matrix<Element, 2, 4> const & rhs) const {
return Matrix<Element, 4, 4>::vcat(*this, rhs);
}
/// Forms a 2-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Matrix<Element, 1, 3> const & B,
Element C, Matrix<Element, 1, 3> const & D) {
return Matrix(
A, B.at(0, 0), B.at(0, 1), B.at(0, 2)
, C, D.at(0, 0), D.at(0, 1), D.at(0, 2)
);
}
/// Forms a 2-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 2> const & A, Matrix<Element, 1, 2> const & B,
Matrix<Element, 1, 2> const & C, Matrix<Element, 1, 2> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1)
, C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1)
);
}
/// Forms a 2-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 3> const & A, Element B,
Matrix<Element, 1, 3> const & C, Element D) {
return Matrix(
A.at(0, 0), A.at(0, 1), A.at(0, 2), B
, C.at(0, 0), C.at(0, 1), C.at(0, 2), D
);
}
/// Elementwise add operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
result.data[6] = data[6] + rhs.data[6];
result.data[7] = data[7] + rhs.data[7];
return result;
}
/// Elementwise add operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
data[6] += rhs.data[6];
data[7] += rhs.data[7];
return *this;
}
/// Elementwise subtract operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
result.data[6] = data[6] - rhs.data[6];
result.data[7] = data[7] - rhs.data[7];
return result;
}
/// Elementwise subtract operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
data[6] -= rhs.data[6];
data[7] -= rhs.data[7];
return *this;
}
/// Elementwise multiply operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
result.data[6] = data[6] * rhs.data[6];
result.data[7] = data[7] * rhs.data[7];
return result;
}
/// Scalar multiply operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
result.data[6] = data[6] * s;
result.data[7] = data[7] * s;
return result;
}
/// Scalar multiply operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
data[6] *= s;
data[7] *= s;
return *this;
}
/// Elementwise divide operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
result.data[6] = data[6] / rhs.data[6];
result.data[7] = data[7] / rhs.data[7];
return result;
}
/// Scalar divide operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
result.data[6] = data[6] / s;
result.data[7] = data[7] / s;
return result;
}
/// Scalar divide operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
data[6] /= s;
data[7] /= s;
return *this;
}
/// Elementwise divide operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (2-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
data[6] /= rhs.data[6];
data[7] /= rhs.data[7];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
m.data[6] = -m.data[6];
m.data[7] = -m.data[7];
return m;
}
/// Matrix product of size 2-by-1-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> product(
Matrix<Element, 4, 1> const &rhs,
Matrix<Element, 2, 1> accum = Matrix<Element, 2, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[4] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[5] * rhs.data[1];
// k=2
accum.data[0] += data[2] * rhs.data[2];
accum.data[1] += data[6] * rhs.data[2];
// k=3
accum.data[0] += data[3] * rhs.data[3];
accum.data[1] += data[7] * rhs.data[3];
return accum;
}
/// Matrix product of size 2-by-1-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> operator*(Matrix<Element, 4, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> product(
Matrix<Element, 4, 2> const &rhs,
Matrix<Element, 2, 2> accum = Matrix<Element, 2, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[4] * rhs.data[0];
accum.data[3] += data[4] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[5] * rhs.data[2];
accum.data[3] += data[5] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
accum.data[2] += data[6] * rhs.data[4];
accum.data[3] += data[6] * rhs.data[5];
// k=3
accum.data[0] += data[3] * rhs.data[6];
accum.data[1] += data[3] * rhs.data[7];
accum.data[2] += data[7] * rhs.data[6];
accum.data[3] += data[7] * rhs.data[7];
return accum;
}
/// Matrix product of size 2-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> operator*(Matrix<Element, 4, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> product(
Matrix<Element, 4, 3> const &rhs,
Matrix<Element, 2, 3> accum = Matrix<Element, 2, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[4] * rhs.data[0];
accum.data[4] += data[4] * rhs.data[1];
accum.data[5] += data[4] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[5] * rhs.data[3];
accum.data[4] += data[5] * rhs.data[4];
accum.data[5] += data[5] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
accum.data[3] += data[6] * rhs.data[6];
accum.data[4] += data[6] * rhs.data[7];
accum.data[5] += data[6] * rhs.data[8];
// k=3
accum.data[0] += data[3] * rhs.data[9];
accum.data[1] += data[3] * rhs.data[10];
accum.data[2] += data[3] * rhs.data[11];
accum.data[3] += data[7] * rhs.data[9];
accum.data[4] += data[7] * rhs.data[10];
accum.data[5] += data[7] * rhs.data[11];
return accum;
}
/// Matrix product of size 2-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> operator*(Matrix<Element, 4, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> product(
Matrix<Element, 4, 4> const &rhs,
Matrix<Element, 2, 4> accum = Matrix<Element, 2, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[4] * rhs.data[0];
accum.data[5] += data[4] * rhs.data[1];
accum.data[6] += data[4] * rhs.data[2];
accum.data[7] += data[4] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[5] * rhs.data[4];
accum.data[5] += data[5] * rhs.data[5];
accum.data[6] += data[5] * rhs.data[6];
accum.data[7] += data[5] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
accum.data[4] += data[6] * rhs.data[8];
accum.data[5] += data[6] * rhs.data[9];
accum.data[6] += data[6] * rhs.data[10];
accum.data[7] += data[6] * rhs.data[11];
// k=3
accum.data[0] += data[3] * rhs.data[12];
accum.data[1] += data[3] * rhs.data[13];
accum.data[2] += data[3] * rhs.data[14];
accum.data[3] += data[3] * rhs.data[15];
accum.data[4] += data[7] * rhs.data[12];
accum.data[5] += data[7] * rhs.data[13];
accum.data[6] += data[7] * rhs.data[14];
accum.data[7] += data[7] * rhs.data[15];
return accum;
}
/// Matrix product of size 2-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> operator*(Matrix<Element, 4, 4> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 2-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 4, 4> const &rhs) {
*this = product(rhs);
return *this;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
accum += data[6];
accum += data[7];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
accum += data[6] * data[6];
accum += data[7] * data[7];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[5];
return accum;
}
};
/// Template alias for 2-by-4 matrix
template <typename Element>
using Matrix2x4 = Matrix<Element, 2, 4>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix2x4<Element> make_Matrix2x4(
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
Element _1_0, Element _1_1, Element _1_2, Element _1_3
) {
return Matrix2x4<Element>(
_0_0, _0_1, _0_2, _0_3,
_1_0, _1_1, _1_2, _1_3
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 3-by-1 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 3, 1> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 3;
/// Number of columns in matrix
static int const kColumns = 1;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 3;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 3-by-1 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 3-by-1 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0,
Element _1_0,
Element _2_0
) {
data[0] = _0_0;
data[1] = _1_0;
data[2] = _2_0;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> transpose() const {
Matrix<Element, 1, 3> mt;
mt.data[0] = data[0];
mt.data[1] = data[1];
mt.data[2] = data[2];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 1 + j + 0];
m.data[1] = data[i * 1 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 1 + j + 0] = m.data[0];
data[i * 1 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 1 + j + 0];
m.data[1] = data[i * 1 + j + 1];
m.data[2] = data[i * 1 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 1 + j + 0] = m.data[0];
data[i * 1 + j + 1] = m.data[1];
data[i * 1 + j + 2] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> column(int j) const {
return slice_3x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 3, 1> const &v, int j =0) {
return set_slice_3x1(v, 0, j);
}
/// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> hcat(Matrix<Element, 3, 1> const & rhs) const {
return Matrix<Element, 3, 2>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> hcat(Matrix<Element, 3, 2> const & rhs) const {
return Matrix<Element, 3, 3>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 3-by-3 matrix to form a 3-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> hcat(Matrix<Element, 3, 3> const & rhs) const {
return Matrix<Element, 3, 4>::hcat(*this, rhs);
}
/// Forms a 3-by-1 matrix by vertically concatenating an Element with a 2-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Element upper, Matrix<Element, 2, 1> const & lower) {
return Matrix(
upper
, lower.at(0, 0)
, lower.at(1, 0));
}
/// Forms a 3-by-1 matrix by vertically concatenating a 2-by-1 matrix with an Element
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 1> const & upper, Element lower) {
return Matrix(
upper.at(0, 0)
, upper.at(1, 0)
, lower);
}
/// Concatenates this matrix with a an Element to form a 4-by-1 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> vcat(Element rhs) const {
return Matrix<Element, 4, 1>::vcat(*this, rhs);
}
/// Elementwise add operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
return result;
}
/// Elementwise add operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
return *this;
}
/// Elementwise subtract operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
return result;
}
/// Elementwise subtract operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
return *this;
}
/// Elementwise multiply operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
return result;
}
/// Scalar multiply operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
return result;
}
/// Scalar multiply operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
return *this;
}
/// Elementwise divide operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
return result;
}
/// Scalar divide operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
return result;
}
/// Scalar divide operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
return *this;
}
/// Elementwise divide operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (3-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
return m;
}
/// Matrix product of size 3-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> product(
Matrix<Element, 1, 1> const &rhs,
Matrix<Element, 3, 1> accum = Matrix<Element, 3, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[1] * rhs.data[0];
accum.data[2] += data[2] * rhs.data[0];
return accum;
}
/// Matrix product of size 3-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> operator*(Matrix<Element, 1, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 1, 1> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 3-by-2-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> product(
Matrix<Element, 1, 2> const &rhs,
Matrix<Element, 3, 2> accum = Matrix<Element, 3, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[1] * rhs.data[0];
accum.data[3] += data[1] * rhs.data[1];
accum.data[4] += data[2] * rhs.data[0];
accum.data[5] += data[2] * rhs.data[1];
return accum;
}
/// Matrix product of size 3-by-2-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> operator*(Matrix<Element, 1, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-3-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> product(
Matrix<Element, 1, 3> const &rhs,
Matrix<Element, 3, 3> accum = Matrix<Element, 3, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[1] * rhs.data[0];
accum.data[4] += data[1] * rhs.data[1];
accum.data[5] += data[1] * rhs.data[2];
accum.data[6] += data[2] * rhs.data[0];
accum.data[7] += data[2] * rhs.data[1];
accum.data[8] += data[2] * rhs.data[2];
return accum;
}
/// Matrix product of size 3-by-3-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> operator*(Matrix<Element, 1, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-4-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> product(
Matrix<Element, 1, 4> const &rhs,
Matrix<Element, 3, 4> accum = Matrix<Element, 3, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[1] * rhs.data[0];
accum.data[5] += data[1] * rhs.data[1];
accum.data[6] += data[1] * rhs.data[2];
accum.data[7] += data[1] * rhs.data[3];
accum.data[8] += data[2] * rhs.data[0];
accum.data[9] += data[2] * rhs.data[1];
accum.data[10] += data[2] * rhs.data[2];
accum.data[11] += data[2] * rhs.data[3];
return accum;
}
/// Matrix product of size 3-by-4-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> operator*(Matrix<Element, 1, 4> const &rhs) const {
return product(rhs);
}
/// Dot product of vectors with extent 3
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 3, 1> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
return accum;
}
/// Dot product of vectors with extent 3
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 1, 3> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
return accum;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
return accum;
}
/// Cross product
CUTLASS_HOST_DEVICE
Matrix cross(Matrix const &rhs) const {
return Matrix(
data[1] * rhs.data[2] - data[2] * rhs.data[1],
data[0] * rhs.data[2] - data[2] * rhs.data[1],
data[0] * rhs.data[1] - data[1] * rhs.data[0]
);
}
};
/// Template alias for 3-by-1 matrix
template <typename Element>
using Matrix3x1 = Matrix<Element, 3, 1>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix3x1<Element> make_Matrix3x1(
Element _0_0,
Element _1_0,
Element _2_0
) {
return Matrix3x1<Element>(
_0_0,
_1_0,
_2_0
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 3-by-2 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 3, 2> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 3;
/// Number of columns in matrix
static int const kColumns = 2;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 6;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 3-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 3-by-2 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1,
Element _1_0, Element _1_1,
Element _2_0, Element _2_1
) {
data[0] = _0_0; data[1] = _0_1;
data[2] = _1_0; data[3] = _1_1;
data[4] = _2_0; data[5] = _2_1;
}
/// Constucts a 3-by-2 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 2> const &row_0,
Matrix<Element, 1, 2> const &row_1,
Matrix<Element, 1, 2> const &row_2
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_1.data[0];
data[3] = row_1.data[1];
data[4] = row_2.data[0];
data[5] = row_2.data[1];
}
/// Static method to construct a 3-by-2 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 2, 1> const &column_0,
Matrix<Element, 2, 1> const &column_1
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_0.data[1];
result.data[3] = column_1.data[1];
result.data[4] = column_0.data[2];
result.data[5] = column_1.data[2];
return result;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 2, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[4] = diag.data[1];
m.data[8] = diag.data[2];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 2> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[4] = diag.data[1];
m.data[8] = diag.data[2];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> diagonal() const {
Matrix<Element, 2, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[4];
diag.data[2] = data[8];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> transpose() const {
Matrix<Element, 2, 3> mt;
mt.data[0] = data[0];
mt.data[3] = data[1];
mt.data[1] = data[2];
mt.data[4] = data[3];
mt.data[2] = data[4];
mt.data[5] = data[5];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> row(int i) const {
return slice_1x2(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 2> const &v, int i = 0) {
return set_slice_1x2(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 2] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
m.data[2] = data[i * 2 + j + 2];
m.data[3] = data[i * 2 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
data[i * 2 + j + 2] = m.data[2];
data[i * 2 + j + 3] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 2];
m.data[2] = data[i * 2 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 2] = m.data[1];
data[i * 2 + j + 4] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> column(int j) const {
return slice_3x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 3, 1> const &v, int j =0) {
return set_slice_3x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> slice_3x2(int i = 0, int j = 0) const {
Matrix<Element, 3, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
m.data[2] = data[i * 2 + j + 2];
m.data[3] = data[i * 2 + j + 3];
m.data[4] = data[i * 2 + j + 4];
m.data[5] = data[i * 2 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x2(Matrix<Element, 3, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
data[i * 2 + j + 2] = m.data[2];
data[i * 2 + j + 3] = m.data[3];
data[i * 2 + j + 4] = m.data[4];
data[i * 2 + j + 5] = m.data[5];
return *this;
}
/// Forms a 3-by-2 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 3, 1> const & lhs, Matrix<Element, 3, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0)
, lhs.at(1, 0), rhs.at(1, 0)
, lhs.at(2, 0), rhs.at(2, 0));
}
/// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> hcat(Matrix<Element, 3, 1> const & rhs) const {
return Matrix<Element, 3, 3>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> hcat(Matrix<Element, 3, 2> const & rhs) const {
return Matrix<Element, 3, 4>::hcat(*this, rhs);
}
/// Forms a 3-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 2-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 2> const & upper, Matrix<Element, 2, 2> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1)
, lower.at(0, 0), lower.at(0, 1)
, lower.at(1, 0), lower.at(1, 1));
}
/// Forms a 3-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 1-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 2> const & upper, Matrix<Element, 1, 2> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1)
, upper.at(1, 0), upper.at(1, 1)
, lower.at(0, 0), lower.at(0, 1));
}
/// Concatenates this matrix with a a 1-by-2 matrix to form a 4-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> vcat(Matrix<Element, 1, 2> const & rhs) const {
return Matrix<Element, 4, 2>::vcat(*this, rhs);
}
/// Forms a 3-by-2 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Element B,
Matrix<Element, 2, 1> const & C, Matrix<Element, 2, 1> const & D) {
return Matrix(
A, B
, C.at(0, 0), D.at(0, 0)
, C.at(1, 0), D.at(1, 0)
);
}
/// Forms a 3-by-2 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 1> const & A, Matrix<Element, 2, 1> const & B,
Element C, Element D) {
return Matrix(
A.at(0, 0), B.at(0, 0)
, A.at(1, 0), B.at(1, 0)
, C, D
);
}
/// Elementwise add operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
return result;
}
/// Elementwise add operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
return *this;
}
/// Elementwise subtract operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
return result;
}
/// Elementwise subtract operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
return *this;
}
/// Elementwise multiply operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
return result;
}
/// Scalar multiply operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
return result;
}
/// Scalar multiply operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
return *this;
}
/// Elementwise divide operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
return result;
}
/// Scalar divide operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
return result;
}
/// Scalar divide operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
return *this;
}
/// Elementwise divide operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (3-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
return m;
}
/// Matrix product of size 3-by-1-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> product(
Matrix<Element, 2, 1> const &rhs,
Matrix<Element, 3, 1> accum = Matrix<Element, 3, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[2] * rhs.data[0];
accum.data[2] += data[4] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[3] * rhs.data[1];
accum.data[2] += data[5] * rhs.data[1];
return accum;
}
/// Matrix product of size 3-by-1-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> operator*(Matrix<Element, 2, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> product(
Matrix<Element, 2, 2> const &rhs,
Matrix<Element, 3, 2> accum = Matrix<Element, 3, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[2] * rhs.data[0];
accum.data[3] += data[2] * rhs.data[1];
accum.data[4] += data[4] * rhs.data[0];
accum.data[5] += data[4] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[3] * rhs.data[2];
accum.data[3] += data[3] * rhs.data[3];
accum.data[4] += data[5] * rhs.data[2];
accum.data[5] += data[5] * rhs.data[3];
return accum;
}
/// Matrix product of size 3-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> operator*(Matrix<Element, 2, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 2, 2> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 3-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> product(
Matrix<Element, 2, 3> const &rhs,
Matrix<Element, 3, 3> accum = Matrix<Element, 3, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[2] * rhs.data[0];
accum.data[4] += data[2] * rhs.data[1];
accum.data[5] += data[2] * rhs.data[2];
accum.data[6] += data[4] * rhs.data[0];
accum.data[7] += data[4] * rhs.data[1];
accum.data[8] += data[4] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[3] * rhs.data[3];
accum.data[4] += data[3] * rhs.data[4];
accum.data[5] += data[3] * rhs.data[5];
accum.data[6] += data[5] * rhs.data[3];
accum.data[7] += data[5] * rhs.data[4];
accum.data[8] += data[5] * rhs.data[5];
return accum;
}
/// Matrix product of size 3-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> operator*(Matrix<Element, 2, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> product(
Matrix<Element, 2, 4> const &rhs,
Matrix<Element, 3, 4> accum = Matrix<Element, 3, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[2] * rhs.data[0];
accum.data[5] += data[2] * rhs.data[1];
accum.data[6] += data[2] * rhs.data[2];
accum.data[7] += data[2] * rhs.data[3];
accum.data[8] += data[4] * rhs.data[0];
accum.data[9] += data[4] * rhs.data[1];
accum.data[10] += data[4] * rhs.data[2];
accum.data[11] += data[4] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[3] * rhs.data[4];
accum.data[5] += data[3] * rhs.data[5];
accum.data[6] += data[3] * rhs.data[6];
accum.data[7] += data[3] * rhs.data[7];
accum.data[8] += data[5] * rhs.data[4];
accum.data[9] += data[5] * rhs.data[5];
accum.data[10] += data[5] * rhs.data[6];
accum.data[11] += data[5] * rhs.data[7];
return accum;
}
/// Matrix product of size 3-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> operator*(Matrix<Element, 2, 4> const &rhs) const {
return product(rhs);
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[3];
return accum;
}
};
/// Template alias for 3-by-2 matrix
template <typename Element>
using Matrix3x2 = Matrix<Element, 3, 2>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix3x2<Element> make_Matrix3x2(
Element _0_0, Element _0_1,
Element _1_0, Element _1_1,
Element _2_0, Element _2_1
) {
return Matrix3x2<Element>(
_0_0, _0_1,
_1_0, _1_1,
_2_0, _2_1
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 3-by-3 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 3, 3> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 3;
/// Number of columns in matrix
static int const kColumns = 3;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 9;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 3-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 3-by-3 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2,
Element _1_0, Element _1_1, Element _1_2,
Element _2_0, Element _2_1, Element _2_2
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2;
data[3] = _1_0; data[4] = _1_1; data[5] = _1_2;
data[6] = _2_0; data[7] = _2_1; data[8] = _2_2;
}
/// Constucts a 3-by-3 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 3> const &row_0,
Matrix<Element, 1, 3> const &row_1,
Matrix<Element, 1, 3> const &row_2
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_0.data[2];
data[3] = row_1.data[0];
data[4] = row_1.data[1];
data[5] = row_1.data[2];
data[6] = row_2.data[0];
data[7] = row_2.data[1];
data[8] = row_2.data[2];
}
/// Static method to construct a 3-by-3 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 3, 1> const &column_0,
Matrix<Element, 3, 1> const &column_1,
Matrix<Element, 3, 1> const &column_2
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_2.data[0];
result.data[3] = column_0.data[1];
result.data[4] = column_1.data[1];
result.data[5] = column_2.data[1];
result.data[6] = column_0.data[2];
result.data[7] = column_1.data[2];
result.data[8] = column_2.data[2];
return result;
}
/// Constructs an identity matrix
CUTLASS_HOST_DEVICE
static Matrix identity() {
Matrix m;
m.data[0] = Element(1);
m.data[4] = Element(1);
m.data[8] = Element(1);
return m;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
m.data[6] = s;
m.data[7] = s;
m.data[8] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 3, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[4] = diag.data[1];
m.data[8] = diag.data[2];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 3> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[4] = diag.data[1];
m.data[8] = diag.data[2];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> diagonal() const {
Matrix<Element, 3, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[4];
diag.data[2] = data[8];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> transpose() const {
Matrix<Element, 3, 3> mt;
mt.data[0] = data[0];
mt.data[3] = data[1];
mt.data[6] = data[2];
mt.data[1] = data[3];
mt.data[4] = data[4];
mt.data[7] = data[5];
mt.data[2] = data[6];
mt.data[5] = data[7];
mt.data[8] = data[8];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> row(int i) const {
return slice_1x3(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 3> const &v, int i = 0) {
return set_slice_1x3(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 3] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 3];
m.data[3] = data[i * 3 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 3] = m.data[2];
data[i * 3 + j + 4] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> slice_2x3(int i = 0, int j = 0) const {
Matrix<Element, 2, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
m.data[3] = data[i * 3 + j + 3];
m.data[4] = data[i * 3 + j + 4];
m.data[5] = data[i * 3 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x3(Matrix<Element, 2, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
data[i * 3 + j + 3] = m.data[3];
data[i * 3 + j + 4] = m.data[4];
data[i * 3 + j + 5] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 3];
m.data[2] = data[i * 3 + j + 6];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 3] = m.data[1];
data[i * 3 + j + 6] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> column(int j) const {
return slice_3x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 3, 1> const &v, int j =0) {
return set_slice_3x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> slice_3x2(int i = 0, int j = 0) const {
Matrix<Element, 3, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 3];
m.data[3] = data[i * 3 + j + 4];
m.data[4] = data[i * 3 + j + 6];
m.data[5] = data[i * 3 + j + 7];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x2(Matrix<Element, 3, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 3] = m.data[2];
data[i * 3 + j + 4] = m.data[3];
data[i * 3 + j + 6] = m.data[4];
data[i * 3 + j + 7] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> slice_3x3(int i = 0, int j = 0) const {
Matrix<Element, 3, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
m.data[3] = data[i * 3 + j + 3];
m.data[4] = data[i * 3 + j + 4];
m.data[5] = data[i * 3 + j + 5];
m.data[6] = data[i * 3 + j + 6];
m.data[7] = data[i * 3 + j + 7];
m.data[8] = data[i * 3 + j + 8];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x3(Matrix<Element, 3, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
data[i * 3 + j + 3] = m.data[3];
data[i * 3 + j + 4] = m.data[4];
data[i * 3 + j + 5] = m.data[5];
data[i * 3 + j + 6] = m.data[6];
data[i * 3 + j + 7] = m.data[7];
data[i * 3 + j + 8] = m.data[8];
return *this;
}
/// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 3, 1> const & lhs, Matrix<Element, 3, 2> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1)
, lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1)
, lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1));
}
/// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 3, 2> const & lhs, Matrix<Element, 3, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0)
, lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0)
, lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0));
}
/// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> hcat(Matrix<Element, 3, 1> const & rhs) const {
return Matrix<Element, 3, 4>::hcat(*this, rhs);
}
/// Forms a 3-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 2-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 3> const & upper, Matrix<Element, 2, 3> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)
, lower.at(1, 0), lower.at(1, 1), lower.at(1, 2));
}
/// Forms a 3-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 1-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 3> const & upper, Matrix<Element, 1, 3> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2)
, upper.at(1, 0), upper.at(1, 1), upper.at(1, 2)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2));
}
/// Concatenates this matrix with a a 1-by-3 matrix to form a 4-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> vcat(Matrix<Element, 1, 3> const & rhs) const {
return Matrix<Element, 4, 3>::vcat(*this, rhs);
}
/// Forms a 3-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Matrix<Element, 1, 2> const & B,
Matrix<Element, 2, 1> const & C, Matrix<Element, 2, 2> const & D) {
return Matrix(
A, B.at(0, 0), B.at(0, 1)
, C.at(0, 0), D.at(0, 0), D.at(0, 1)
, C.at(1, 0), D.at(1, 0), D.at(1, 1)
);
}
/// Forms a 3-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 2> const & A, Element B,
Matrix<Element, 2, 2> const & C, Matrix<Element, 2, 1> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B
, C.at(0, 0), C.at(0, 1), D.at(0, 0)
, C.at(1, 0), C.at(1, 1), D.at(1, 0)
);
}
/// Forms a 3-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 1> const & A, Matrix<Element, 2, 2> const & B,
Element C, Matrix<Element, 1, 2> const & D) {
return Matrix(
A.at(0, 0), B.at(0, 0), B.at(0, 1)
, A.at(1, 0), B.at(1, 0), B.at(1, 1)
, C, D.at(0, 0), D.at(0, 1)
);
}
/// Forms a 3-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 2> const & A, Matrix<Element, 2, 1> const & B,
Matrix<Element, 1, 2> const & C, Element D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0)
, A.at(1, 0), A.at(1, 1), B.at(1, 0)
, C.at(0, 0), C.at(0, 1), D
);
}
/// Elementwise add operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
result.data[6] = data[6] + rhs.data[6];
result.data[7] = data[7] + rhs.data[7];
result.data[8] = data[8] + rhs.data[8];
return result;
}
/// Elementwise add operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
data[6] += rhs.data[6];
data[7] += rhs.data[7];
data[8] += rhs.data[8];
return *this;
}
/// Elementwise subtract operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
result.data[6] = data[6] - rhs.data[6];
result.data[7] = data[7] - rhs.data[7];
result.data[8] = data[8] - rhs.data[8];
return result;
}
/// Elementwise subtract operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
data[6] -= rhs.data[6];
data[7] -= rhs.data[7];
data[8] -= rhs.data[8];
return *this;
}
/// Elementwise multiply operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
result.data[6] = data[6] * rhs.data[6];
result.data[7] = data[7] * rhs.data[7];
result.data[8] = data[8] * rhs.data[8];
return result;
}
/// Scalar multiply operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
result.data[6] = data[6] * s;
result.data[7] = data[7] * s;
result.data[8] = data[8] * s;
return result;
}
/// Scalar multiply operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
data[6] *= s;
data[7] *= s;
data[8] *= s;
return *this;
}
/// Elementwise divide operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
result.data[6] = data[6] / rhs.data[6];
result.data[7] = data[7] / rhs.data[7];
result.data[8] = data[8] / rhs.data[8];
return result;
}
/// Scalar divide operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
result.data[6] = data[6] / s;
result.data[7] = data[7] / s;
result.data[8] = data[8] / s;
return result;
}
/// Scalar divide operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
data[6] /= s;
data[7] /= s;
data[8] /= s;
return *this;
}
/// Elementwise divide operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (3-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
data[6] /= rhs.data[6];
data[7] /= rhs.data[7];
data[8] /= rhs.data[8];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
m.data[6] = -m.data[6];
m.data[7] = -m.data[7];
m.data[8] = -m.data[8];
return m;
}
/// Matrix product of size 3-by-1-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> product(
Matrix<Element, 3, 1> const &rhs,
Matrix<Element, 3, 1> accum = Matrix<Element, 3, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[3] * rhs.data[0];
accum.data[2] += data[6] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[4] * rhs.data[1];
accum.data[2] += data[7] * rhs.data[1];
// k=2
accum.data[0] += data[2] * rhs.data[2];
accum.data[1] += data[5] * rhs.data[2];
accum.data[2] += data[8] * rhs.data[2];
return accum;
}
/// Matrix product of size 3-by-1-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> operator*(Matrix<Element, 3, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> product(
Matrix<Element, 3, 2> const &rhs,
Matrix<Element, 3, 2> accum = Matrix<Element, 3, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[3] * rhs.data[0];
accum.data[3] += data[3] * rhs.data[1];
accum.data[4] += data[6] * rhs.data[0];
accum.data[5] += data[6] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[4] * rhs.data[2];
accum.data[3] += data[4] * rhs.data[3];
accum.data[4] += data[7] * rhs.data[2];
accum.data[5] += data[7] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
accum.data[2] += data[5] * rhs.data[4];
accum.data[3] += data[5] * rhs.data[5];
accum.data[4] += data[8] * rhs.data[4];
accum.data[5] += data[8] * rhs.data[5];
return accum;
}
/// Matrix product of size 3-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> operator*(Matrix<Element, 3, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> product(
Matrix<Element, 3, 3> const &rhs,
Matrix<Element, 3, 3> accum = Matrix<Element, 3, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[3] * rhs.data[0];
accum.data[4] += data[3] * rhs.data[1];
accum.data[5] += data[3] * rhs.data[2];
accum.data[6] += data[6] * rhs.data[0];
accum.data[7] += data[6] * rhs.data[1];
accum.data[8] += data[6] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[4] * rhs.data[3];
accum.data[4] += data[4] * rhs.data[4];
accum.data[5] += data[4] * rhs.data[5];
accum.data[6] += data[7] * rhs.data[3];
accum.data[7] += data[7] * rhs.data[4];
accum.data[8] += data[7] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
accum.data[3] += data[5] * rhs.data[6];
accum.data[4] += data[5] * rhs.data[7];
accum.data[5] += data[5] * rhs.data[8];
accum.data[6] += data[8] * rhs.data[6];
accum.data[7] += data[8] * rhs.data[7];
accum.data[8] += data[8] * rhs.data[8];
return accum;
}
/// Matrix product of size 3-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> operator*(Matrix<Element, 3, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 3, 3> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 3-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> product(
Matrix<Element, 3, 4> const &rhs,
Matrix<Element, 3, 4> accum = Matrix<Element, 3, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[3] * rhs.data[0];
accum.data[5] += data[3] * rhs.data[1];
accum.data[6] += data[3] * rhs.data[2];
accum.data[7] += data[3] * rhs.data[3];
accum.data[8] += data[6] * rhs.data[0];
accum.data[9] += data[6] * rhs.data[1];
accum.data[10] += data[6] * rhs.data[2];
accum.data[11] += data[6] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[4] * rhs.data[4];
accum.data[5] += data[4] * rhs.data[5];
accum.data[6] += data[4] * rhs.data[6];
accum.data[7] += data[4] * rhs.data[7];
accum.data[8] += data[7] * rhs.data[4];
accum.data[9] += data[7] * rhs.data[5];
accum.data[10] += data[7] * rhs.data[6];
accum.data[11] += data[7] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
accum.data[4] += data[5] * rhs.data[8];
accum.data[5] += data[5] * rhs.data[9];
accum.data[6] += data[5] * rhs.data[10];
accum.data[7] += data[5] * rhs.data[11];
accum.data[8] += data[8] * rhs.data[8];
accum.data[9] += data[8] * rhs.data[9];
accum.data[10] += data[8] * rhs.data[10];
accum.data[11] += data[8] * rhs.data[11];
return accum;
}
/// Matrix product of size 3-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> operator*(Matrix<Element, 3, 4> const &rhs) const {
return product(rhs);
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
accum += data[6];
accum += data[7];
accum += data[8];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
accum += data[6] * data[6];
accum += data[7] * data[7];
accum += data[8] * data[8];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[4];
accum += data[8];
return accum;
}
/// Returns 3-by-3 rotation matrix around the X axis
CUTLASS_HOST_DEVICE
static Matrix rotation_X(Element theta) {
Matrix m = identity();
Element c = fast_cos(theta);
Element s = fast_sin(theta);
m.at(1, 1) = c;
m.at(1, 2) = -s;
m.at(2, 1) = s;
m.at(2, 2) = c;
return m;
}
/// Returns 3-by-3 rotation matrix around the Y axis
CUTLASS_HOST_DEVICE
static Matrix rotation_Y(Element theta) {
Matrix m = identity();
Element c = fast_cos(theta);
Element s = fast_sin(theta);
m.at(0, 0) = c;
m.at(2, 0) = -s;
m.at(0, 2) = s;
m.at(2, 2) = c;
return m;
}
/// Returns 3-by-3 rotation matrix around the Z axis
CUTLASS_HOST_DEVICE
static Matrix rotation_Z(Element theta) {
Matrix m = Matrix::identity();
Element c = fast_cos(theta);
Element s = fast_sin(theta);
m.at(0, 0) = c;
m.at(0, 1) = -s;
m.at(1, 0) = s;
m.at(1, 1) = c;
return m;
}
/// Returns a 3-by-3 rotation matrix around a unit-length axis
CUTLASS_HOST_DEVICE
static Matrix rotation(Element theta, Matrix<Element, 3, 1> const &u) {
Element x = u.data[0];
Element y = u.data[1];
Element z = u.data[2];
Element c = fast_cos(theta);
Element s = fast_sin(theta);
Element one_minus_cos = Element(1) - fast_cos(theta);
Matrix m;
m.set_slice3x3({
c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s,
y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s,
z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos
});
return m;
}
/// Returns a 3-by-3 reflection about the plane specified by the
/// unit-length normal vector n_unit
CUTLASS_HOST_DEVICE
static Matrix reflection(Matrix<Element, 3, 1> const &n_unit) {
Element a = n_unit.data[0];
Element b = n_unit.data[1];
Element c = n_unit.data[2];
Matrix m = Matrix::identity();
m.set_slice3x3({
Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c,
Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c,
Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c
});
return m;
}
/// Computes the determinant of a 3-by-3 matrix
CUTLASS_HOST_DEVICE
Element determinant(Element accum = Element()) const {
accum += at(0, 0) * Matrix<Element, 2, 2>({ at(1, 1), at(1, 2), at(2, 1), at(2, 2) }).determinant();
accum -= at(0, 1) * Matrix<Element, 2, 2>({ at(1, 0), at(1, 2), at(2, 0), at(2, 2) }).determinant();
accum += at(0, 2) * Matrix<Element, 2, 2>({ at(1, 0), at(1, 1), at(2, 0), at(2, 1) }).determinant();
return accum;
}
/// Computes the inverse of a 3-by-3 matrix given
/// the matrix's determinant
CUTLASS_HOST_DEVICE
Matrix inverse(Element det) const {
return Matrix(
at(1, 1) * at(2, 2) - at(1, 2) * at(2, 1),
at(0, 2) * at(2, 1) - at(0, 1) * at(2, 2),
at(0, 1) * at(1, 2) - at(0, 2) * at(1, 1),
at(1, 2) * at(2, 0) - at(1, 0) * at(2, 2),
at(0, 0) * at(2, 2) - at(0, 2) * at(2, 0),
at(0, 2) * at(1, 0) - at(0, 0) * at(1, 2),
at(1, 0) * at(2, 1) - at(1, 1) * at(2, 0),
at(0, 1) * at(2, 0) - at(0, 0) * at(2, 1),
at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0)
) * (Element(1) / det);
}
/// Computes the inverse of a 3-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix inverse() const {
return inverse(determinant());
}
};
/// Template alias for 3-by-3 matrix
template <typename Element>
using Matrix3x3 = Matrix<Element, 3, 3>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix3x3<Element> make_Matrix3x3(
Element _0_0, Element _0_1, Element _0_2,
Element _1_0, Element _1_1, Element _1_2,
Element _2_0, Element _2_1, Element _2_2
) {
return Matrix3x3<Element>(
_0_0, _0_1, _0_2,
_1_0, _1_1, _1_2,
_2_0, _2_1, _2_2
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 3-by-4 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 3, 4> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 3;
/// Number of columns in matrix
static int const kColumns = 4;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 12;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 3-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 3-by-4 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
Element _1_0, Element _1_1, Element _1_2, Element _1_3,
Element _2_0, Element _2_1, Element _2_2, Element _2_3
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3;
data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3;
data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3;
}
/// Constucts a 3-by-4 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 4> const &row_0,
Matrix<Element, 1, 4> const &row_1,
Matrix<Element, 1, 4> const &row_2
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_0.data[2];
data[3] = row_0.data[3];
data[4] = row_1.data[0];
data[5] = row_1.data[1];
data[6] = row_1.data[2];
data[7] = row_1.data[3];
data[8] = row_2.data[0];
data[9] = row_2.data[1];
data[10] = row_2.data[2];
data[11] = row_2.data[3];
}
/// Static method to construct a 3-by-4 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 4, 1> const &column_0,
Matrix<Element, 4, 1> const &column_1,
Matrix<Element, 4, 1> const &column_2,
Matrix<Element, 4, 1> const &column_3
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_2.data[0];
result.data[3] = column_3.data[0];
result.data[4] = column_0.data[1];
result.data[5] = column_1.data[1];
result.data[6] = column_2.data[1];
result.data[7] = column_3.data[1];
result.data[8] = column_0.data[2];
result.data[9] = column_1.data[2];
result.data[10] = column_2.data[2];
result.data[11] = column_3.data[2];
return result;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
m.data[6] = s;
m.data[7] = s;
m.data[8] = s;
m.data[9] = s;
m.data[10] = s;
m.data[11] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 3, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[4] = diag.data[1];
m.data[8] = diag.data[2];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 3> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[4] = diag.data[1];
m.data[8] = diag.data[2];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> diagonal() const {
Matrix<Element, 3, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[4];
diag.data[2] = data[8];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> transpose() const {
Matrix<Element, 4, 3> mt;
mt.data[0] = data[0];
mt.data[3] = data[1];
mt.data[6] = data[2];
mt.data[9] = data[3];
mt.data[1] = data[4];
mt.data[4] = data[5];
mt.data[7] = data[6];
mt.data[10] = data[7];
mt.data[2] = data[8];
mt.data[5] = data[9];
mt.data[8] = data[10];
mt.data[11] = data[11];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 3 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> slice_1x4(int i = 0, int j = 0) const {
Matrix<Element, 1, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x4(Matrix<Element, 1, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> row(int i) const {
return slice_1x4(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 4> const &v, int i = 0) {
return set_slice_1x4(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 4] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 4];
m.data[3] = data[i * 4 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 4] = m.data[2];
data[i * 4 + j + 5] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> slice_2x3(int i = 0, int j = 0) const {
Matrix<Element, 2, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 4];
m.data[4] = data[i * 4 + j + 5];
m.data[5] = data[i * 4 + j + 6];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x3(Matrix<Element, 2, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 4] = m.data[3];
data[i * 4 + j + 5] = m.data[4];
data[i * 4 + j + 6] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> slice_2x4(int i = 0, int j = 0) const {
Matrix<Element, 2, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
m.data[4] = data[i * 4 + j + 4];
m.data[5] = data[i * 4 + j + 5];
m.data[6] = data[i * 4 + j + 6];
m.data[7] = data[i * 4 + j + 7];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x4(Matrix<Element, 2, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
data[i * 4 + j + 4] = m.data[4];
data[i * 4 + j + 5] = m.data[5];
data[i * 4 + j + 6] = m.data[6];
data[i * 4 + j + 7] = m.data[7];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 4];
m.data[2] = data[i * 4 + j + 8];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 4] = m.data[1];
data[i * 4 + j + 8] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> column(int j) const {
return slice_3x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 3, 1> const &v, int j =0) {
return set_slice_3x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> slice_3x2(int i = 0, int j = 0) const {
Matrix<Element, 3, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 4];
m.data[3] = data[i * 4 + j + 5];
m.data[4] = data[i * 4 + j + 8];
m.data[5] = data[i * 4 + j + 9];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x2(Matrix<Element, 3, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 4] = m.data[2];
data[i * 4 + j + 5] = m.data[3];
data[i * 4 + j + 8] = m.data[4];
data[i * 4 + j + 9] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> slice_3x3(int i = 0, int j = 0) const {
Matrix<Element, 3, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 4];
m.data[4] = data[i * 4 + j + 5];
m.data[5] = data[i * 4 + j + 6];
m.data[6] = data[i * 4 + j + 8];
m.data[7] = data[i * 4 + j + 9];
m.data[8] = data[i * 4 + j + 10];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x3(Matrix<Element, 3, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 4] = m.data[3];
data[i * 4 + j + 5] = m.data[4];
data[i * 4 + j + 6] = m.data[5];
data[i * 4 + j + 8] = m.data[6];
data[i * 4 + j + 9] = m.data[7];
data[i * 4 + j + 10] = m.data[8];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> slice_3x4(int i = 0, int j = 0) const {
Matrix<Element, 3, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
m.data[4] = data[i * 4 + j + 4];
m.data[5] = data[i * 4 + j + 5];
m.data[6] = data[i * 4 + j + 6];
m.data[7] = data[i * 4 + j + 7];
m.data[8] = data[i * 4 + j + 8];
m.data[9] = data[i * 4 + j + 9];
m.data[10] = data[i * 4 + j + 10];
m.data[11] = data[i * 4 + j + 11];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x4(Matrix<Element, 3, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
data[i * 4 + j + 4] = m.data[4];
data[i * 4 + j + 5] = m.data[5];
data[i * 4 + j + 6] = m.data[6];
data[i * 4 + j + 7] = m.data[7];
data[i * 4 + j + 8] = m.data[8];
data[i * 4 + j + 9] = m.data[9];
data[i * 4 + j + 10] = m.data[10];
data[i * 4 + j + 11] = m.data[11];
return *this;
}
/// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 3, 1> const & lhs, Matrix<Element, 3, 3> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2)
, lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2)
, lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2));
}
/// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 3, 2> const & lhs, Matrix<Element, 3, 2> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1)
, lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1)
, lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1));
}
/// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-3 matrix with a 3-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 3, 3> const & lhs, Matrix<Element, 3, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0)
, lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0)
, lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0));
}
/// Forms a 3-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 2-by-4 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 4> const & upper, Matrix<Element, 2, 4> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)
, lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3));
}
/// Forms a 3-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 1-by-4 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 4> const & upper, Matrix<Element, 1, 4> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3)
, upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3));
}
/// Concatenates this matrix with a a 1-by-4 matrix to form a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> vcat(Matrix<Element, 1, 4> const & rhs) const {
return Matrix<Element, 4, 4>::vcat(*this, rhs);
}
/// Forms a 3-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Matrix<Element, 1, 3> const & B,
Matrix<Element, 2, 1> const & C, Matrix<Element, 2, 3> const & D) {
return Matrix(
A, B.at(0, 0), B.at(0, 1), B.at(0, 2)
, C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2)
, C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2)
);
}
/// Forms a 3-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 2> const & A, Matrix<Element, 1, 2> const & B,
Matrix<Element, 2, 2> const & C, Matrix<Element, 2, 2> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1)
, C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1)
, C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1)
);
}
/// Forms a 3-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 3> const & A, Element B,
Matrix<Element, 2, 3> const & C, Matrix<Element, 2, 1> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), A.at(0, 2), B
, C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0)
, C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0)
);
}
/// Forms a 3-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 1> const & A, Matrix<Element, 2, 3> const & B,
Element C, Matrix<Element, 1, 3> const & D) {
return Matrix(
A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2)
, A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2)
, C, D.at(0, 0), D.at(0, 1), D.at(0, 2)
);
}
/// Forms a 3-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 2> const & A, Matrix<Element, 2, 2> const & B,
Matrix<Element, 1, 2> const & C, Matrix<Element, 1, 2> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1)
, A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1)
, C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1)
);
}
/// Forms a 3-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 3> const & A, Matrix<Element, 2, 1> const & B,
Matrix<Element, 1, 3> const & C, Element D) {
return Matrix(
A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0)
, A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0)
, C.at(0, 0), C.at(0, 1), C.at(0, 2), D
);
}
/// Elementwise add operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
result.data[6] = data[6] + rhs.data[6];
result.data[7] = data[7] + rhs.data[7];
result.data[8] = data[8] + rhs.data[8];
result.data[9] = data[9] + rhs.data[9];
result.data[10] = data[10] + rhs.data[10];
result.data[11] = data[11] + rhs.data[11];
return result;
}
/// Elementwise add operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
data[6] += rhs.data[6];
data[7] += rhs.data[7];
data[8] += rhs.data[8];
data[9] += rhs.data[9];
data[10] += rhs.data[10];
data[11] += rhs.data[11];
return *this;
}
/// Elementwise subtract operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
result.data[6] = data[6] - rhs.data[6];
result.data[7] = data[7] - rhs.data[7];
result.data[8] = data[8] - rhs.data[8];
result.data[9] = data[9] - rhs.data[9];
result.data[10] = data[10] - rhs.data[10];
result.data[11] = data[11] - rhs.data[11];
return result;
}
/// Elementwise subtract operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
data[6] -= rhs.data[6];
data[7] -= rhs.data[7];
data[8] -= rhs.data[8];
data[9] -= rhs.data[9];
data[10] -= rhs.data[10];
data[11] -= rhs.data[11];
return *this;
}
/// Elementwise multiply operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
result.data[6] = data[6] * rhs.data[6];
result.data[7] = data[7] * rhs.data[7];
result.data[8] = data[8] * rhs.data[8];
result.data[9] = data[9] * rhs.data[9];
result.data[10] = data[10] * rhs.data[10];
result.data[11] = data[11] * rhs.data[11];
return result;
}
/// Scalar multiply operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
result.data[6] = data[6] * s;
result.data[7] = data[7] * s;
result.data[8] = data[8] * s;
result.data[9] = data[9] * s;
result.data[10] = data[10] * s;
result.data[11] = data[11] * s;
return result;
}
/// Scalar multiply operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
data[6] *= s;
data[7] *= s;
data[8] *= s;
data[9] *= s;
data[10] *= s;
data[11] *= s;
return *this;
}
/// Elementwise divide operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
result.data[6] = data[6] / rhs.data[6];
result.data[7] = data[7] / rhs.data[7];
result.data[8] = data[8] / rhs.data[8];
result.data[9] = data[9] / rhs.data[9];
result.data[10] = data[10] / rhs.data[10];
result.data[11] = data[11] / rhs.data[11];
return result;
}
/// Scalar divide operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
result.data[6] = data[6] / s;
result.data[7] = data[7] / s;
result.data[8] = data[8] / s;
result.data[9] = data[9] / s;
result.data[10] = data[10] / s;
result.data[11] = data[11] / s;
return result;
}
/// Scalar divide operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
data[6] /= s;
data[7] /= s;
data[8] /= s;
data[9] /= s;
data[10] /= s;
data[11] /= s;
return *this;
}
/// Elementwise divide operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (3-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
data[6] /= rhs.data[6];
data[7] /= rhs.data[7];
data[8] /= rhs.data[8];
data[9] /= rhs.data[9];
data[10] /= rhs.data[10];
data[11] /= rhs.data[11];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
m.data[6] = -m.data[6];
m.data[7] = -m.data[7];
m.data[8] = -m.data[8];
m.data[9] = -m.data[9];
m.data[10] = -m.data[10];
m.data[11] = -m.data[11];
return m;
}
/// Matrix product of size 3-by-1-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> product(
Matrix<Element, 4, 1> const &rhs,
Matrix<Element, 3, 1> accum = Matrix<Element, 3, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[4] * rhs.data[0];
accum.data[2] += data[8] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[5] * rhs.data[1];
accum.data[2] += data[9] * rhs.data[1];
// k=2
accum.data[0] += data[2] * rhs.data[2];
accum.data[1] += data[6] * rhs.data[2];
accum.data[2] += data[10] * rhs.data[2];
// k=3
accum.data[0] += data[3] * rhs.data[3];
accum.data[1] += data[7] * rhs.data[3];
accum.data[2] += data[11] * rhs.data[3];
return accum;
}
/// Matrix product of size 3-by-1-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> operator*(Matrix<Element, 4, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> product(
Matrix<Element, 4, 2> const &rhs,
Matrix<Element, 3, 2> accum = Matrix<Element, 3, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[4] * rhs.data[0];
accum.data[3] += data[4] * rhs.data[1];
accum.data[4] += data[8] * rhs.data[0];
accum.data[5] += data[8] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[5] * rhs.data[2];
accum.data[3] += data[5] * rhs.data[3];
accum.data[4] += data[9] * rhs.data[2];
accum.data[5] += data[9] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
accum.data[2] += data[6] * rhs.data[4];
accum.data[3] += data[6] * rhs.data[5];
accum.data[4] += data[10] * rhs.data[4];
accum.data[5] += data[10] * rhs.data[5];
// k=3
accum.data[0] += data[3] * rhs.data[6];
accum.data[1] += data[3] * rhs.data[7];
accum.data[2] += data[7] * rhs.data[6];
accum.data[3] += data[7] * rhs.data[7];
accum.data[4] += data[11] * rhs.data[6];
accum.data[5] += data[11] * rhs.data[7];
return accum;
}
/// Matrix product of size 3-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> operator*(Matrix<Element, 4, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> product(
Matrix<Element, 4, 3> const &rhs,
Matrix<Element, 3, 3> accum = Matrix<Element, 3, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[4] * rhs.data[0];
accum.data[4] += data[4] * rhs.data[1];
accum.data[5] += data[4] * rhs.data[2];
accum.data[6] += data[8] * rhs.data[0];
accum.data[7] += data[8] * rhs.data[1];
accum.data[8] += data[8] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[5] * rhs.data[3];
accum.data[4] += data[5] * rhs.data[4];
accum.data[5] += data[5] * rhs.data[5];
accum.data[6] += data[9] * rhs.data[3];
accum.data[7] += data[9] * rhs.data[4];
accum.data[8] += data[9] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
accum.data[3] += data[6] * rhs.data[6];
accum.data[4] += data[6] * rhs.data[7];
accum.data[5] += data[6] * rhs.data[8];
accum.data[6] += data[10] * rhs.data[6];
accum.data[7] += data[10] * rhs.data[7];
accum.data[8] += data[10] * rhs.data[8];
// k=3
accum.data[0] += data[3] * rhs.data[9];
accum.data[1] += data[3] * rhs.data[10];
accum.data[2] += data[3] * rhs.data[11];
accum.data[3] += data[7] * rhs.data[9];
accum.data[4] += data[7] * rhs.data[10];
accum.data[5] += data[7] * rhs.data[11];
accum.data[6] += data[11] * rhs.data[9];
accum.data[7] += data[11] * rhs.data[10];
accum.data[8] += data[11] * rhs.data[11];
return accum;
}
/// Matrix product of size 3-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> operator*(Matrix<Element, 4, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> product(
Matrix<Element, 4, 4> const &rhs,
Matrix<Element, 3, 4> accum = Matrix<Element, 3, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[4] * rhs.data[0];
accum.data[5] += data[4] * rhs.data[1];
accum.data[6] += data[4] * rhs.data[2];
accum.data[7] += data[4] * rhs.data[3];
accum.data[8] += data[8] * rhs.data[0];
accum.data[9] += data[8] * rhs.data[1];
accum.data[10] += data[8] * rhs.data[2];
accum.data[11] += data[8] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[5] * rhs.data[4];
accum.data[5] += data[5] * rhs.data[5];
accum.data[6] += data[5] * rhs.data[6];
accum.data[7] += data[5] * rhs.data[7];
accum.data[8] += data[9] * rhs.data[4];
accum.data[9] += data[9] * rhs.data[5];
accum.data[10] += data[9] * rhs.data[6];
accum.data[11] += data[9] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
accum.data[4] += data[6] * rhs.data[8];
accum.data[5] += data[6] * rhs.data[9];
accum.data[6] += data[6] * rhs.data[10];
accum.data[7] += data[6] * rhs.data[11];
accum.data[8] += data[10] * rhs.data[8];
accum.data[9] += data[10] * rhs.data[9];
accum.data[10] += data[10] * rhs.data[10];
accum.data[11] += data[10] * rhs.data[11];
// k=3
accum.data[0] += data[3] * rhs.data[12];
accum.data[1] += data[3] * rhs.data[13];
accum.data[2] += data[3] * rhs.data[14];
accum.data[3] += data[3] * rhs.data[15];
accum.data[4] += data[7] * rhs.data[12];
accum.data[5] += data[7] * rhs.data[13];
accum.data[6] += data[7] * rhs.data[14];
accum.data[7] += data[7] * rhs.data[15];
accum.data[8] += data[11] * rhs.data[12];
accum.data[9] += data[11] * rhs.data[13];
accum.data[10] += data[11] * rhs.data[14];
accum.data[11] += data[11] * rhs.data[15];
return accum;
}
/// Matrix product of size 3-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> operator*(Matrix<Element, 4, 4> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 3-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 4, 4> const &rhs) {
*this = product(rhs);
return *this;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
accum += data[6];
accum += data[7];
accum += data[8];
accum += data[9];
accum += data[10];
accum += data[11];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
accum += data[6] * data[6];
accum += data[7] * data[7];
accum += data[8] * data[8];
accum += data[9] * data[9];
accum += data[10] * data[10];
accum += data[11] * data[11];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[5];
accum += data[10];
return accum;
}
};
/// Template alias for 3-by-4 matrix
template <typename Element>
using Matrix3x4 = Matrix<Element, 3, 4>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix3x4<Element> make_Matrix3x4(
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
Element _1_0, Element _1_1, Element _1_2, Element _1_3,
Element _2_0, Element _2_1, Element _2_2, Element _2_3
) {
return Matrix3x4<Element>(
_0_0, _0_1, _0_2, _0_3,
_1_0, _1_1, _1_2, _1_3,
_2_0, _2_1, _2_2, _2_3
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 4-by-1 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 4, 1> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 4;
/// Number of columns in matrix
static int const kColumns = 1;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 4;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 4-by-1 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 4-by-1 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0,
Element _1_0,
Element _2_0,
Element _3_0
) {
data[0] = _0_0;
data[1] = _1_0;
data[2] = _2_0;
data[3] = _3_0;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> transpose() const {
Matrix<Element, 1, 4> mt;
mt.data[0] = data[0];
mt.data[1] = data[1];
mt.data[2] = data[2];
mt.data[3] = data[3];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 1 + j + 0];
m.data[1] = data[i * 1 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 1 + j + 0] = m.data[0];
data[i * 1 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 1 + j + 0];
m.data[1] = data[i * 1 + j + 1];
m.data[2] = data[i * 1 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 1 + j + 0] = m.data[0];
data[i * 1 + j + 1] = m.data[1];
data[i * 1 + j + 2] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> slice_4x1(int i = 0, int j = 0) const {
Matrix<Element, 4, 1> m;
m.data[0] = data[i * 1 + j + 0];
m.data[1] = data[i * 1 + j + 1];
m.data[2] = data[i * 1 + j + 2];
m.data[3] = data[i * 1 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x1(Matrix<Element, 4, 1> const &m, int i = 0, int j = 0) {
data[i * 1 + j + 0] = m.data[0];
data[i * 1 + j + 1] = m.data[1];
data[i * 1 + j + 2] = m.data[2];
data[i * 1 + j + 3] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> column(int j) const {
return slice_4x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 4, 1> const &v, int j =0) {
return set_slice_4x1(v, 0, j);
}
/// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> hcat(Matrix<Element, 4, 1> const & rhs) const {
return Matrix<Element, 4, 2>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> hcat(Matrix<Element, 4, 2> const & rhs) const {
return Matrix<Element, 4, 3>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 4-by-3 matrix to form a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> hcat(Matrix<Element, 4, 3> const & rhs) const {
return Matrix<Element, 4, 4>::hcat(*this, rhs);
}
/// Forms a 4-by-1 matrix by vertically concatenating an Element with a 3-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Element upper, Matrix<Element, 3, 1> const & lower) {
return Matrix(
upper
, lower.at(0, 0)
, lower.at(1, 0)
, lower.at(2, 0));
}
/// Forms a 4-by-1 matrix by vertically concatenating a 2-by-1 matrix with a 2-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 1> const & upper, Matrix<Element, 2, 1> const & lower) {
return Matrix(
upper.at(0, 0)
, upper.at(1, 0)
, lower.at(0, 0)
, lower.at(1, 0));
}
/// Forms a 4-by-1 matrix by vertically concatenating a 3-by-1 matrix with an Element
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 3, 1> const & upper, Element lower) {
return Matrix(
upper.at(0, 0)
, upper.at(1, 0)
, upper.at(2, 0)
, lower);
}
/// Elementwise add operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
return result;
}
/// Elementwise add operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
return *this;
}
/// Elementwise subtract operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
return result;
}
/// Elementwise subtract operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
return *this;
}
/// Elementwise multiply operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
return result;
}
/// Scalar multiply operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
return result;
}
/// Scalar multiply operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
return *this;
}
/// Elementwise divide operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
return result;
}
/// Scalar divide operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
return result;
}
/// Scalar divide operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
return *this;
}
/// Elementwise divide operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (4-by-1)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
return m;
}
/// Matrix product of size 4-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> product(
Matrix<Element, 1, 1> const &rhs,
Matrix<Element, 4, 1> accum = Matrix<Element, 4, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[1] * rhs.data[0];
accum.data[2] += data[2] * rhs.data[0];
accum.data[3] += data[3] * rhs.data[0];
return accum;
}
/// Matrix product of size 4-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> operator*(Matrix<Element, 1, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-1-by-1
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 1, 1> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 4-by-2-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> product(
Matrix<Element, 1, 2> const &rhs,
Matrix<Element, 4, 2> accum = Matrix<Element, 4, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[1] * rhs.data[0];
accum.data[3] += data[1] * rhs.data[1];
accum.data[4] += data[2] * rhs.data[0];
accum.data[5] += data[2] * rhs.data[1];
accum.data[6] += data[3] * rhs.data[0];
accum.data[7] += data[3] * rhs.data[1];
return accum;
}
/// Matrix product of size 4-by-2-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> operator*(Matrix<Element, 1, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-3-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> product(
Matrix<Element, 1, 3> const &rhs,
Matrix<Element, 4, 3> accum = Matrix<Element, 4, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[1] * rhs.data[0];
accum.data[4] += data[1] * rhs.data[1];
accum.data[5] += data[1] * rhs.data[2];
accum.data[6] += data[2] * rhs.data[0];
accum.data[7] += data[2] * rhs.data[1];
accum.data[8] += data[2] * rhs.data[2];
accum.data[9] += data[3] * rhs.data[0];
accum.data[10] += data[3] * rhs.data[1];
accum.data[11] += data[3] * rhs.data[2];
return accum;
}
/// Matrix product of size 4-by-3-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> operator*(Matrix<Element, 1, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-4-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> product(
Matrix<Element, 1, 4> const &rhs,
Matrix<Element, 4, 4> accum = Matrix<Element, 4, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[1] * rhs.data[0];
accum.data[5] += data[1] * rhs.data[1];
accum.data[6] += data[1] * rhs.data[2];
accum.data[7] += data[1] * rhs.data[3];
accum.data[8] += data[2] * rhs.data[0];
accum.data[9] += data[2] * rhs.data[1];
accum.data[10] += data[2] * rhs.data[2];
accum.data[11] += data[2] * rhs.data[3];
accum.data[12] += data[3] * rhs.data[0];
accum.data[13] += data[3] * rhs.data[1];
accum.data[14] += data[3] * rhs.data[2];
accum.data[15] += data[3] * rhs.data[3];
return accum;
}
/// Matrix product of size 4-by-4-by-1
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> operator*(Matrix<Element, 1, 4> const &rhs) const {
return product(rhs);
}
/// Dot product of vectors with extent 4
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 4, 1> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
accum += data[3] * rhs.data[3];
return accum;
}
/// Dot product of vectors with extent 4
CUTLASS_HOST_DEVICE
Element dot(Matrix<Element, 1, 4> const &rhs, Element accum = Element()) const {
accum += data[0] * rhs.data[0];
accum += data[1] * rhs.data[1];
accum += data[2] * rhs.data[2];
accum += data[3] * rhs.data[3];
return accum;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
return accum;
}
};
/// Template alias for 4-by-1 matrix
template <typename Element>
using Matrix4x1 = Matrix<Element, 4, 1>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix4x1<Element> make_Matrix4x1(
Element _0_0,
Element _1_0,
Element _2_0,
Element _3_0
) {
return Matrix4x1<Element>(
_0_0,
_1_0,
_2_0,
_3_0
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 4-by-2 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 4, 2> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 4;
/// Number of columns in matrix
static int const kColumns = 2;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 8;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 4-by-2 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 4-by-2 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1,
Element _1_0, Element _1_1,
Element _2_0, Element _2_1,
Element _3_0, Element _3_1
) {
data[0] = _0_0; data[1] = _0_1;
data[2] = _1_0; data[3] = _1_1;
data[4] = _2_0; data[5] = _2_1;
data[6] = _3_0; data[7] = _3_1;
}
/// Constucts a 4-by-2 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 2> const &row_0,
Matrix<Element, 1, 2> const &row_1,
Matrix<Element, 1, 2> const &row_2,
Matrix<Element, 1, 2> const &row_3
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_1.data[0];
data[3] = row_1.data[1];
data[4] = row_2.data[0];
data[5] = row_2.data[1];
data[6] = row_3.data[0];
data[7] = row_3.data[1];
}
/// Static method to construct a 4-by-2 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 2, 1> const &column_0,
Matrix<Element, 2, 1> const &column_1
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_0.data[1];
result.data[3] = column_1.data[1];
result.data[4] = column_0.data[2];
result.data[5] = column_1.data[2];
result.data[6] = column_0.data[3];
result.data[7] = column_1.data[3];
return result;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
m.data[6] = s;
m.data[7] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 2, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[5] = diag.data[1];
m.data[10] = diag.data[2];
m.data[15] = diag.data[3];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 2> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[5] = diag.data[1];
m.data[10] = diag.data[2];
m.data[15] = diag.data[3];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> diagonal() const {
Matrix<Element, 2, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[5];
diag.data[2] = data[10];
diag.data[3] = data[15];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> transpose() const {
Matrix<Element, 2, 4> mt;
mt.data[0] = data[0];
mt.data[4] = data[1];
mt.data[1] = data[2];
mt.data[5] = data[3];
mt.data[2] = data[4];
mt.data[6] = data[5];
mt.data[3] = data[6];
mt.data[7] = data[7];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> row(int i) const {
return slice_1x2(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 2> const &v, int i = 0) {
return set_slice_1x2(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 2] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
m.data[2] = data[i * 2 + j + 2];
m.data[3] = data[i * 2 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
data[i * 2 + j + 2] = m.data[2];
data[i * 2 + j + 3] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 2];
m.data[2] = data[i * 2 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 2] = m.data[1];
data[i * 2 + j + 4] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> slice_3x2(int i = 0, int j = 0) const {
Matrix<Element, 3, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
m.data[2] = data[i * 2 + j + 2];
m.data[3] = data[i * 2 + j + 3];
m.data[4] = data[i * 2 + j + 4];
m.data[5] = data[i * 2 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x2(Matrix<Element, 3, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
data[i * 2 + j + 2] = m.data[2];
data[i * 2 + j + 3] = m.data[3];
data[i * 2 + j + 4] = m.data[4];
data[i * 2 + j + 5] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> slice_4x1(int i = 0, int j = 0) const {
Matrix<Element, 4, 1> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 2];
m.data[2] = data[i * 2 + j + 4];
m.data[3] = data[i * 2 + j + 6];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x1(Matrix<Element, 4, 1> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 2] = m.data[1];
data[i * 2 + j + 4] = m.data[2];
data[i * 2 + j + 6] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> column(int j) const {
return slice_4x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 4, 1> const &v, int j =0) {
return set_slice_4x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> slice_4x2(int i = 0, int j = 0) const {
Matrix<Element, 4, 2> m;
m.data[0] = data[i * 2 + j + 0];
m.data[1] = data[i * 2 + j + 1];
m.data[2] = data[i * 2 + j + 2];
m.data[3] = data[i * 2 + j + 3];
m.data[4] = data[i * 2 + j + 4];
m.data[5] = data[i * 2 + j + 5];
m.data[6] = data[i * 2 + j + 6];
m.data[7] = data[i * 2 + j + 7];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x2(Matrix<Element, 4, 2> const &m, int i = 0, int j = 0) {
data[i * 2 + j + 0] = m.data[0];
data[i * 2 + j + 1] = m.data[1];
data[i * 2 + j + 2] = m.data[2];
data[i * 2 + j + 3] = m.data[3];
data[i * 2 + j + 4] = m.data[4];
data[i * 2 + j + 5] = m.data[5];
data[i * 2 + j + 6] = m.data[6];
data[i * 2 + j + 7] = m.data[7];
return *this;
}
/// Forms a 4-by-2 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 4, 1> const & lhs, Matrix<Element, 4, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0)
, lhs.at(1, 0), rhs.at(1, 0)
, lhs.at(2, 0), rhs.at(2, 0)
, lhs.at(3, 0), rhs.at(3, 0));
}
/// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> hcat(Matrix<Element, 4, 1> const & rhs) const {
return Matrix<Element, 4, 3>::hcat(*this, rhs);
}
/// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> hcat(Matrix<Element, 4, 2> const & rhs) const {
return Matrix<Element, 4, 4>::hcat(*this, rhs);
}
/// Forms a 4-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 3-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 2> const & upper, Matrix<Element, 3, 2> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1)
, lower.at(0, 0), lower.at(0, 1)
, lower.at(1, 0), lower.at(1, 1)
, lower.at(2, 0), lower.at(2, 1));
}
/// Forms a 4-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 2-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 2> const & upper, Matrix<Element, 2, 2> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1)
, upper.at(1, 0), upper.at(1, 1)
, lower.at(0, 0), lower.at(0, 1)
, lower.at(1, 0), lower.at(1, 1));
}
/// Forms a 4-by-2 matrix by vertically concatenating a 3-by-2 matrix with a 1-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 3, 2> const & upper, Matrix<Element, 1, 2> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1)
, upper.at(1, 0), upper.at(1, 1)
, upper.at(2, 0), upper.at(2, 1)
, lower.at(0, 0), lower.at(0, 1));
}
/// Forms a 4-by-2 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Element B,
Matrix<Element, 3, 1> const & C, Matrix<Element, 3, 1> const & D) {
return Matrix(
A, B
, C.at(0, 0), D.at(0, 0)
, C.at(1, 0), D.at(1, 0)
, C.at(2, 0), D.at(2, 0)
);
}
/// Forms a 4-by-2 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 1> const & A, Matrix<Element, 2, 1> const & B,
Matrix<Element, 2, 1> const & C, Matrix<Element, 2, 1> const & D) {
return Matrix(
A.at(0, 0), B.at(0, 0)
, A.at(1, 0), B.at(1, 0)
, C.at(0, 0), D.at(0, 0)
, C.at(1, 0), D.at(1, 0)
);
}
/// Forms a 4-by-2 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 3, 1> const & A, Matrix<Element, 3, 1> const & B,
Element C, Element D) {
return Matrix(
A.at(0, 0), B.at(0, 0)
, A.at(1, 0), B.at(1, 0)
, A.at(2, 0), B.at(2, 0)
, C, D
);
}
/// Elementwise add operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
result.data[6] = data[6] + rhs.data[6];
result.data[7] = data[7] + rhs.data[7];
return result;
}
/// Elementwise add operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
data[6] += rhs.data[6];
data[7] += rhs.data[7];
return *this;
}
/// Elementwise subtract operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
result.data[6] = data[6] - rhs.data[6];
result.data[7] = data[7] - rhs.data[7];
return result;
}
/// Elementwise subtract operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
data[6] -= rhs.data[6];
data[7] -= rhs.data[7];
return *this;
}
/// Elementwise multiply operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
result.data[6] = data[6] * rhs.data[6];
result.data[7] = data[7] * rhs.data[7];
return result;
}
/// Scalar multiply operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
result.data[6] = data[6] * s;
result.data[7] = data[7] * s;
return result;
}
/// Scalar multiply operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
data[6] *= s;
data[7] *= s;
return *this;
}
/// Elementwise divide operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
result.data[6] = data[6] / rhs.data[6];
result.data[7] = data[7] / rhs.data[7];
return result;
}
/// Scalar divide operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
result.data[6] = data[6] / s;
result.data[7] = data[7] / s;
return result;
}
/// Scalar divide operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
data[6] /= s;
data[7] /= s;
return *this;
}
/// Elementwise divide operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (4-by-2)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
data[6] /= rhs.data[6];
data[7] /= rhs.data[7];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
m.data[6] = -m.data[6];
m.data[7] = -m.data[7];
return m;
}
/// Matrix product of size 4-by-1-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> product(
Matrix<Element, 2, 1> const &rhs,
Matrix<Element, 4, 1> accum = Matrix<Element, 4, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[2] * rhs.data[0];
accum.data[2] += data[4] * rhs.data[0];
accum.data[3] += data[6] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[3] * rhs.data[1];
accum.data[2] += data[5] * rhs.data[1];
accum.data[3] += data[7] * rhs.data[1];
return accum;
}
/// Matrix product of size 4-by-1-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> operator*(Matrix<Element, 2, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> product(
Matrix<Element, 2, 2> const &rhs,
Matrix<Element, 4, 2> accum = Matrix<Element, 4, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[2] * rhs.data[0];
accum.data[3] += data[2] * rhs.data[1];
accum.data[4] += data[4] * rhs.data[0];
accum.data[5] += data[4] * rhs.data[1];
accum.data[6] += data[6] * rhs.data[0];
accum.data[7] += data[6] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[3] * rhs.data[2];
accum.data[3] += data[3] * rhs.data[3];
accum.data[4] += data[5] * rhs.data[2];
accum.data[5] += data[5] * rhs.data[3];
accum.data[6] += data[7] * rhs.data[2];
accum.data[7] += data[7] * rhs.data[3];
return accum;
}
/// Matrix product of size 4-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> operator*(Matrix<Element, 2, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-2-by-2
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 2, 2> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 4-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> product(
Matrix<Element, 2, 3> const &rhs,
Matrix<Element, 4, 3> accum = Matrix<Element, 4, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[2] * rhs.data[0];
accum.data[4] += data[2] * rhs.data[1];
accum.data[5] += data[2] * rhs.data[2];
accum.data[6] += data[4] * rhs.data[0];
accum.data[7] += data[4] * rhs.data[1];
accum.data[8] += data[4] * rhs.data[2];
accum.data[9] += data[6] * rhs.data[0];
accum.data[10] += data[6] * rhs.data[1];
accum.data[11] += data[6] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[3] * rhs.data[3];
accum.data[4] += data[3] * rhs.data[4];
accum.data[5] += data[3] * rhs.data[5];
accum.data[6] += data[5] * rhs.data[3];
accum.data[7] += data[5] * rhs.data[4];
accum.data[8] += data[5] * rhs.data[5];
accum.data[9] += data[7] * rhs.data[3];
accum.data[10] += data[7] * rhs.data[4];
accum.data[11] += data[7] * rhs.data[5];
return accum;
}
/// Matrix product of size 4-by-3-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> operator*(Matrix<Element, 2, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> product(
Matrix<Element, 2, 4> const &rhs,
Matrix<Element, 4, 4> accum = Matrix<Element, 4, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[2] * rhs.data[0];
accum.data[5] += data[2] * rhs.data[1];
accum.data[6] += data[2] * rhs.data[2];
accum.data[7] += data[2] * rhs.data[3];
accum.data[8] += data[4] * rhs.data[0];
accum.data[9] += data[4] * rhs.data[1];
accum.data[10] += data[4] * rhs.data[2];
accum.data[11] += data[4] * rhs.data[3];
accum.data[12] += data[6] * rhs.data[0];
accum.data[13] += data[6] * rhs.data[1];
accum.data[14] += data[6] * rhs.data[2];
accum.data[15] += data[6] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[3] * rhs.data[4];
accum.data[5] += data[3] * rhs.data[5];
accum.data[6] += data[3] * rhs.data[6];
accum.data[7] += data[3] * rhs.data[7];
accum.data[8] += data[5] * rhs.data[4];
accum.data[9] += data[5] * rhs.data[5];
accum.data[10] += data[5] * rhs.data[6];
accum.data[11] += data[5] * rhs.data[7];
accum.data[12] += data[7] * rhs.data[4];
accum.data[13] += data[7] * rhs.data[5];
accum.data[14] += data[7] * rhs.data[6];
accum.data[15] += data[7] * rhs.data[7];
return accum;
}
/// Matrix product of size 4-by-4-by-2
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> operator*(Matrix<Element, 2, 4> const &rhs) const {
return product(rhs);
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
accum += data[6];
accum += data[7];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
accum += data[6] * data[6];
accum += data[7] * data[7];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[3];
return accum;
}
};
/// Template alias for 4-by-2 matrix
template <typename Element>
using Matrix4x2 = Matrix<Element, 4, 2>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix4x2<Element> make_Matrix4x2(
Element _0_0, Element _0_1,
Element _1_0, Element _1_1,
Element _2_0, Element _2_1,
Element _3_0, Element _3_1
) {
return Matrix4x2<Element>(
_0_0, _0_1,
_1_0, _1_1,
_2_0, _2_1,
_3_0, _3_1
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 4-by-3 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 4, 3> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 4;
/// Number of columns in matrix
static int const kColumns = 3;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 12;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 4-by-3 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 4-by-3 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2,
Element _1_0, Element _1_1, Element _1_2,
Element _2_0, Element _2_1, Element _2_2,
Element _3_0, Element _3_1, Element _3_2
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2;
data[3] = _1_0; data[4] = _1_1; data[5] = _1_2;
data[6] = _2_0; data[7] = _2_1; data[8] = _2_2;
data[9] = _3_0; data[10] = _3_1; data[11] = _3_2;
}
/// Constucts a 4-by-3 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 3> const &row_0,
Matrix<Element, 1, 3> const &row_1,
Matrix<Element, 1, 3> const &row_2,
Matrix<Element, 1, 3> const &row_3
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_0.data[2];
data[3] = row_1.data[0];
data[4] = row_1.data[1];
data[5] = row_1.data[2];
data[6] = row_2.data[0];
data[7] = row_2.data[1];
data[8] = row_2.data[2];
data[9] = row_3.data[0];
data[10] = row_3.data[1];
data[11] = row_3.data[2];
}
/// Static method to construct a 4-by-3 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 3, 1> const &column_0,
Matrix<Element, 3, 1> const &column_1,
Matrix<Element, 3, 1> const &column_2
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_2.data[0];
result.data[3] = column_0.data[1];
result.data[4] = column_1.data[1];
result.data[5] = column_2.data[1];
result.data[6] = column_0.data[2];
result.data[7] = column_1.data[2];
result.data[8] = column_2.data[2];
result.data[9] = column_0.data[3];
result.data[10] = column_1.data[3];
result.data[11] = column_2.data[3];
return result;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
m.data[6] = s;
m.data[7] = s;
m.data[8] = s;
m.data[9] = s;
m.data[10] = s;
m.data[11] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 3, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[5] = diag.data[1];
m.data[10] = diag.data[2];
m.data[15] = diag.data[3];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 3> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[5] = diag.data[1];
m.data[10] = diag.data[2];
m.data[15] = diag.data[3];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> diagonal() const {
Matrix<Element, 3, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[5];
diag.data[2] = data[10];
diag.data[3] = data[15];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> transpose() const {
Matrix<Element, 3, 4> mt;
mt.data[0] = data[0];
mt.data[4] = data[1];
mt.data[8] = data[2];
mt.data[1] = data[3];
mt.data[5] = data[4];
mt.data[9] = data[5];
mt.data[2] = data[6];
mt.data[6] = data[7];
mt.data[10] = data[8];
mt.data[3] = data[9];
mt.data[7] = data[10];
mt.data[11] = data[11];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> row(int i) const {
return slice_1x3(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 3> const &v, int i = 0) {
return set_slice_1x3(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 3] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 3];
m.data[3] = data[i * 3 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 3] = m.data[2];
data[i * 3 + j + 4] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> slice_2x3(int i = 0, int j = 0) const {
Matrix<Element, 2, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
m.data[3] = data[i * 3 + j + 3];
m.data[4] = data[i * 3 + j + 4];
m.data[5] = data[i * 3 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x3(Matrix<Element, 2, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
data[i * 3 + j + 3] = m.data[3];
data[i * 3 + j + 4] = m.data[4];
data[i * 3 + j + 5] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 3];
m.data[2] = data[i * 3 + j + 6];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 3] = m.data[1];
data[i * 3 + j + 6] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> slice_3x2(int i = 0, int j = 0) const {
Matrix<Element, 3, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 3];
m.data[3] = data[i * 3 + j + 4];
m.data[4] = data[i * 3 + j + 6];
m.data[5] = data[i * 3 + j + 7];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x2(Matrix<Element, 3, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 3] = m.data[2];
data[i * 3 + j + 4] = m.data[3];
data[i * 3 + j + 6] = m.data[4];
data[i * 3 + j + 7] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> slice_3x3(int i = 0, int j = 0) const {
Matrix<Element, 3, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
m.data[3] = data[i * 3 + j + 3];
m.data[4] = data[i * 3 + j + 4];
m.data[5] = data[i * 3 + j + 5];
m.data[6] = data[i * 3 + j + 6];
m.data[7] = data[i * 3 + j + 7];
m.data[8] = data[i * 3 + j + 8];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x3(Matrix<Element, 3, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
data[i * 3 + j + 3] = m.data[3];
data[i * 3 + j + 4] = m.data[4];
data[i * 3 + j + 5] = m.data[5];
data[i * 3 + j + 6] = m.data[6];
data[i * 3 + j + 7] = m.data[7];
data[i * 3 + j + 8] = m.data[8];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> slice_4x1(int i = 0, int j = 0) const {
Matrix<Element, 4, 1> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 3];
m.data[2] = data[i * 3 + j + 6];
m.data[3] = data[i * 3 + j + 9];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x1(Matrix<Element, 4, 1> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 3] = m.data[1];
data[i * 3 + j + 6] = m.data[2];
data[i * 3 + j + 9] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> column(int j) const {
return slice_4x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 4, 1> const &v, int j =0) {
return set_slice_4x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> slice_4x2(int i = 0, int j = 0) const {
Matrix<Element, 4, 2> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 3];
m.data[3] = data[i * 3 + j + 4];
m.data[4] = data[i * 3 + j + 6];
m.data[5] = data[i * 3 + j + 7];
m.data[6] = data[i * 3 + j + 9];
m.data[7] = data[i * 3 + j + 10];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x2(Matrix<Element, 4, 2> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 3] = m.data[2];
data[i * 3 + j + 4] = m.data[3];
data[i * 3 + j + 6] = m.data[4];
data[i * 3 + j + 7] = m.data[5];
data[i * 3 + j + 9] = m.data[6];
data[i * 3 + j + 10] = m.data[7];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> slice_4x3(int i = 0, int j = 0) const {
Matrix<Element, 4, 3> m;
m.data[0] = data[i * 3 + j + 0];
m.data[1] = data[i * 3 + j + 1];
m.data[2] = data[i * 3 + j + 2];
m.data[3] = data[i * 3 + j + 3];
m.data[4] = data[i * 3 + j + 4];
m.data[5] = data[i * 3 + j + 5];
m.data[6] = data[i * 3 + j + 6];
m.data[7] = data[i * 3 + j + 7];
m.data[8] = data[i * 3 + j + 8];
m.data[9] = data[i * 3 + j + 9];
m.data[10] = data[i * 3 + j + 10];
m.data[11] = data[i * 3 + j + 11];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x3(Matrix<Element, 4, 3> const &m, int i = 0, int j = 0) {
data[i * 3 + j + 0] = m.data[0];
data[i * 3 + j + 1] = m.data[1];
data[i * 3 + j + 2] = m.data[2];
data[i * 3 + j + 3] = m.data[3];
data[i * 3 + j + 4] = m.data[4];
data[i * 3 + j + 5] = m.data[5];
data[i * 3 + j + 6] = m.data[6];
data[i * 3 + j + 7] = m.data[7];
data[i * 3 + j + 8] = m.data[8];
data[i * 3 + j + 9] = m.data[9];
data[i * 3 + j + 10] = m.data[10];
data[i * 3 + j + 11] = m.data[11];
return *this;
}
/// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 4, 1> const & lhs, Matrix<Element, 4, 2> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1)
, lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1)
, lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1)
, lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1));
}
/// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 4, 2> const & lhs, Matrix<Element, 4, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0)
, lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0)
, lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0)
, lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0));
}
/// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> hcat(Matrix<Element, 4, 1> const & rhs) const {
return Matrix<Element, 4, 4>::hcat(*this, rhs);
}
/// Forms a 4-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 3-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 3> const & upper, Matrix<Element, 3, 3> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)
, lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)
, lower.at(2, 0), lower.at(2, 1), lower.at(2, 2));
}
/// Forms a 4-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 2-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 3> const & upper, Matrix<Element, 2, 3> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2)
, upper.at(1, 0), upper.at(1, 1), upper.at(1, 2)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)
, lower.at(1, 0), lower.at(1, 1), lower.at(1, 2));
}
/// Forms a 4-by-3 matrix by vertically concatenating a 3-by-3 matrix with a 1-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 3, 3> const & upper, Matrix<Element, 1, 3> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2)
, upper.at(1, 0), upper.at(1, 1), upper.at(1, 2)
, upper.at(2, 0), upper.at(2, 1), upper.at(2, 2)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2));
}
/// Forms a 4-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Matrix<Element, 1, 2> const & B,
Matrix<Element, 3, 1> const & C, Matrix<Element, 3, 2> const & D) {
return Matrix(
A, B.at(0, 0), B.at(0, 1)
, C.at(0, 0), D.at(0, 0), D.at(0, 1)
, C.at(1, 0), D.at(1, 0), D.at(1, 1)
, C.at(2, 0), D.at(2, 0), D.at(2, 1)
);
}
/// Forms a 4-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 2> const & A, Element B,
Matrix<Element, 3, 2> const & C, Matrix<Element, 3, 1> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B
, C.at(0, 0), C.at(0, 1), D.at(0, 0)
, C.at(1, 0), C.at(1, 1), D.at(1, 0)
, C.at(2, 0), C.at(2, 1), D.at(2, 0)
);
}
/// Forms a 4-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 1> const & A, Matrix<Element, 2, 2> const & B,
Matrix<Element, 2, 1> const & C, Matrix<Element, 2, 2> const & D) {
return Matrix(
A.at(0, 0), B.at(0, 0), B.at(0, 1)
, A.at(1, 0), B.at(1, 0), B.at(1, 1)
, C.at(0, 0), D.at(0, 0), D.at(0, 1)
, C.at(1, 0), D.at(1, 0), D.at(1, 1)
);
}
/// Forms a 4-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 2> const & A, Matrix<Element, 2, 1> const & B,
Matrix<Element, 2, 2> const & C, Matrix<Element, 2, 1> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0)
, A.at(1, 0), A.at(1, 1), B.at(1, 0)
, C.at(0, 0), C.at(0, 1), D.at(0, 0)
, C.at(1, 0), C.at(1, 1), D.at(1, 0)
);
}
/// Forms a 4-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 3, 1> const & A, Matrix<Element, 3, 2> const & B,
Element C, Matrix<Element, 1, 2> const & D) {
return Matrix(
A.at(0, 0), B.at(0, 0), B.at(0, 1)
, A.at(1, 0), B.at(1, 0), B.at(1, 1)
, A.at(2, 0), B.at(2, 0), B.at(2, 1)
, C, D.at(0, 0), D.at(0, 1)
);
}
/// Forms a 4-by-3 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 3, 2> const & A, Matrix<Element, 3, 1> const & B,
Matrix<Element, 1, 2> const & C, Element D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0)
, A.at(1, 0), A.at(1, 1), B.at(1, 0)
, A.at(2, 0), A.at(2, 1), B.at(2, 0)
, C.at(0, 0), C.at(0, 1), D
);
}
/// Elementwise add operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
result.data[6] = data[6] + rhs.data[6];
result.data[7] = data[7] + rhs.data[7];
result.data[8] = data[8] + rhs.data[8];
result.data[9] = data[9] + rhs.data[9];
result.data[10] = data[10] + rhs.data[10];
result.data[11] = data[11] + rhs.data[11];
return result;
}
/// Elementwise add operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
data[6] += rhs.data[6];
data[7] += rhs.data[7];
data[8] += rhs.data[8];
data[9] += rhs.data[9];
data[10] += rhs.data[10];
data[11] += rhs.data[11];
return *this;
}
/// Elementwise subtract operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
result.data[6] = data[6] - rhs.data[6];
result.data[7] = data[7] - rhs.data[7];
result.data[8] = data[8] - rhs.data[8];
result.data[9] = data[9] - rhs.data[9];
result.data[10] = data[10] - rhs.data[10];
result.data[11] = data[11] - rhs.data[11];
return result;
}
/// Elementwise subtract operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
data[6] -= rhs.data[6];
data[7] -= rhs.data[7];
data[8] -= rhs.data[8];
data[9] -= rhs.data[9];
data[10] -= rhs.data[10];
data[11] -= rhs.data[11];
return *this;
}
/// Elementwise multiply operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
result.data[6] = data[6] * rhs.data[6];
result.data[7] = data[7] * rhs.data[7];
result.data[8] = data[8] * rhs.data[8];
result.data[9] = data[9] * rhs.data[9];
result.data[10] = data[10] * rhs.data[10];
result.data[11] = data[11] * rhs.data[11];
return result;
}
/// Scalar multiply operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
result.data[6] = data[6] * s;
result.data[7] = data[7] * s;
result.data[8] = data[8] * s;
result.data[9] = data[9] * s;
result.data[10] = data[10] * s;
result.data[11] = data[11] * s;
return result;
}
/// Scalar multiply operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
data[6] *= s;
data[7] *= s;
data[8] *= s;
data[9] *= s;
data[10] *= s;
data[11] *= s;
return *this;
}
/// Elementwise divide operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
result.data[6] = data[6] / rhs.data[6];
result.data[7] = data[7] / rhs.data[7];
result.data[8] = data[8] / rhs.data[8];
result.data[9] = data[9] / rhs.data[9];
result.data[10] = data[10] / rhs.data[10];
result.data[11] = data[11] / rhs.data[11];
return result;
}
/// Scalar divide operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
result.data[6] = data[6] / s;
result.data[7] = data[7] / s;
result.data[8] = data[8] / s;
result.data[9] = data[9] / s;
result.data[10] = data[10] / s;
result.data[11] = data[11] / s;
return result;
}
/// Scalar divide operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
data[6] /= s;
data[7] /= s;
data[8] /= s;
data[9] /= s;
data[10] /= s;
data[11] /= s;
return *this;
}
/// Elementwise divide operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (4-by-3)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
data[6] /= rhs.data[6];
data[7] /= rhs.data[7];
data[8] /= rhs.data[8];
data[9] /= rhs.data[9];
data[10] /= rhs.data[10];
data[11] /= rhs.data[11];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
m.data[6] = -m.data[6];
m.data[7] = -m.data[7];
m.data[8] = -m.data[8];
m.data[9] = -m.data[9];
m.data[10] = -m.data[10];
m.data[11] = -m.data[11];
return m;
}
/// Matrix product of size 4-by-1-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> product(
Matrix<Element, 3, 1> const &rhs,
Matrix<Element, 4, 1> accum = Matrix<Element, 4, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[3] * rhs.data[0];
accum.data[2] += data[6] * rhs.data[0];
accum.data[3] += data[9] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[4] * rhs.data[1];
accum.data[2] += data[7] * rhs.data[1];
accum.data[3] += data[10] * rhs.data[1];
// k=2
accum.data[0] += data[2] * rhs.data[2];
accum.data[1] += data[5] * rhs.data[2];
accum.data[2] += data[8] * rhs.data[2];
accum.data[3] += data[11] * rhs.data[2];
return accum;
}
/// Matrix product of size 4-by-1-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> operator*(Matrix<Element, 3, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> product(
Matrix<Element, 3, 2> const &rhs,
Matrix<Element, 4, 2> accum = Matrix<Element, 4, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[3] * rhs.data[0];
accum.data[3] += data[3] * rhs.data[1];
accum.data[4] += data[6] * rhs.data[0];
accum.data[5] += data[6] * rhs.data[1];
accum.data[6] += data[9] * rhs.data[0];
accum.data[7] += data[9] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[4] * rhs.data[2];
accum.data[3] += data[4] * rhs.data[3];
accum.data[4] += data[7] * rhs.data[2];
accum.data[5] += data[7] * rhs.data[3];
accum.data[6] += data[10] * rhs.data[2];
accum.data[7] += data[10] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
accum.data[2] += data[5] * rhs.data[4];
accum.data[3] += data[5] * rhs.data[5];
accum.data[4] += data[8] * rhs.data[4];
accum.data[5] += data[8] * rhs.data[5];
accum.data[6] += data[11] * rhs.data[4];
accum.data[7] += data[11] * rhs.data[5];
return accum;
}
/// Matrix product of size 4-by-2-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> operator*(Matrix<Element, 3, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> product(
Matrix<Element, 3, 3> const &rhs,
Matrix<Element, 4, 3> accum = Matrix<Element, 4, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[3] * rhs.data[0];
accum.data[4] += data[3] * rhs.data[1];
accum.data[5] += data[3] * rhs.data[2];
accum.data[6] += data[6] * rhs.data[0];
accum.data[7] += data[6] * rhs.data[1];
accum.data[8] += data[6] * rhs.data[2];
accum.data[9] += data[9] * rhs.data[0];
accum.data[10] += data[9] * rhs.data[1];
accum.data[11] += data[9] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[4] * rhs.data[3];
accum.data[4] += data[4] * rhs.data[4];
accum.data[5] += data[4] * rhs.data[5];
accum.data[6] += data[7] * rhs.data[3];
accum.data[7] += data[7] * rhs.data[4];
accum.data[8] += data[7] * rhs.data[5];
accum.data[9] += data[10] * rhs.data[3];
accum.data[10] += data[10] * rhs.data[4];
accum.data[11] += data[10] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
accum.data[3] += data[5] * rhs.data[6];
accum.data[4] += data[5] * rhs.data[7];
accum.data[5] += data[5] * rhs.data[8];
accum.data[6] += data[8] * rhs.data[6];
accum.data[7] += data[8] * rhs.data[7];
accum.data[8] += data[8] * rhs.data[8];
accum.data[9] += data[11] * rhs.data[6];
accum.data[10] += data[11] * rhs.data[7];
accum.data[11] += data[11] * rhs.data[8];
return accum;
}
/// Matrix product of size 4-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> operator*(Matrix<Element, 3, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-3-by-3
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 3, 3> const &rhs) {
*this = product(rhs);
return *this;
}
/// Matrix product of size 4-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> product(
Matrix<Element, 3, 4> const &rhs,
Matrix<Element, 4, 4> accum = Matrix<Element, 4, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[3] * rhs.data[0];
accum.data[5] += data[3] * rhs.data[1];
accum.data[6] += data[3] * rhs.data[2];
accum.data[7] += data[3] * rhs.data[3];
accum.data[8] += data[6] * rhs.data[0];
accum.data[9] += data[6] * rhs.data[1];
accum.data[10] += data[6] * rhs.data[2];
accum.data[11] += data[6] * rhs.data[3];
accum.data[12] += data[9] * rhs.data[0];
accum.data[13] += data[9] * rhs.data[1];
accum.data[14] += data[9] * rhs.data[2];
accum.data[15] += data[9] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[4] * rhs.data[4];
accum.data[5] += data[4] * rhs.data[5];
accum.data[6] += data[4] * rhs.data[6];
accum.data[7] += data[4] * rhs.data[7];
accum.data[8] += data[7] * rhs.data[4];
accum.data[9] += data[7] * rhs.data[5];
accum.data[10] += data[7] * rhs.data[6];
accum.data[11] += data[7] * rhs.data[7];
accum.data[12] += data[10] * rhs.data[4];
accum.data[13] += data[10] * rhs.data[5];
accum.data[14] += data[10] * rhs.data[6];
accum.data[15] += data[10] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
accum.data[4] += data[5] * rhs.data[8];
accum.data[5] += data[5] * rhs.data[9];
accum.data[6] += data[5] * rhs.data[10];
accum.data[7] += data[5] * rhs.data[11];
accum.data[8] += data[8] * rhs.data[8];
accum.data[9] += data[8] * rhs.data[9];
accum.data[10] += data[8] * rhs.data[10];
accum.data[11] += data[8] * rhs.data[11];
accum.data[12] += data[11] * rhs.data[8];
accum.data[13] += data[11] * rhs.data[9];
accum.data[14] += data[11] * rhs.data[10];
accum.data[15] += data[11] * rhs.data[11];
return accum;
}
/// Matrix product of size 4-by-4-by-3
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> operator*(Matrix<Element, 3, 4> const &rhs) const {
return product(rhs);
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
accum += data[6];
accum += data[7];
accum += data[8];
accum += data[9];
accum += data[10];
accum += data[11];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
accum += data[6] * data[6];
accum += data[7] * data[7];
accum += data[8] * data[8];
accum += data[9] * data[9];
accum += data[10] * data[10];
accum += data[11] * data[11];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[4];
accum += data[8];
return accum;
}
};
/// Template alias for 4-by-3 matrix
template <typename Element>
using Matrix4x3 = Matrix<Element, 4, 3>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix4x3<Element> make_Matrix4x3(
Element _0_0, Element _0_1, Element _0_2,
Element _1_0, Element _1_1, Element _1_2,
Element _2_0, Element _2_1, Element _2_2,
Element _3_0, Element _3_1, Element _3_2
) {
return Matrix4x3<Element>(
_0_0, _0_1, _0_2,
_1_0, _1_1, _1_2,
_2_0, _2_1, _2_2,
_3_0, _3_1, _3_2
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// 4-by-4 matrix template class definition
template <typename Element_>
struct Matrix<Element_, 4, 4> {
//
// Type definitions
//
/// Element data type
using Element = Element_;
/// Number of rows in matrix
static int const kRows = 4;
/// Number of columns in matrix
static int const kColumns = 4;
/// Layout of matrix in underlying array
using Layout = layout::RowMajor;
/// Number of elements in matrix
static int const kCount = 16;
//
// Data members
//
/// Elements of the matrix in row-major layout
Array<Element, kCount> data;
//
// Methods
//
/// Constructs a zero matrix
CUTLASS_HOST_DEVICE
Matrix() {
data.clear();
}
/// Copy constructor for a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Matrix(Matrix const &rhs) {
data = rhs.data;
}
/// Constucts a 4-by-4 matrix from scalar elements
CUTLASS_HOST_DEVICE
Matrix(
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
Element _1_0, Element _1_1, Element _1_2, Element _1_3,
Element _2_0, Element _2_1, Element _2_2, Element _2_3,
Element _3_0, Element _3_1, Element _3_2, Element _3_3
) {
data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3;
data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3;
data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3;
data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3;
}
/// Constucts a 4-by-4 matrix from row vectors
CUTLASS_HOST_DEVICE
Matrix(
Matrix<Element, 1, 4> const &row_0,
Matrix<Element, 1, 4> const &row_1,
Matrix<Element, 1, 4> const &row_2,
Matrix<Element, 1, 4> const &row_3
) {
data[0] = row_0.data[0];
data[1] = row_0.data[1];
data[2] = row_0.data[2];
data[3] = row_0.data[3];
data[4] = row_1.data[0];
data[5] = row_1.data[1];
data[6] = row_1.data[2];
data[7] = row_1.data[3];
data[8] = row_2.data[0];
data[9] = row_2.data[1];
data[10] = row_2.data[2];
data[11] = row_2.data[3];
data[12] = row_3.data[0];
data[13] = row_3.data[1];
data[14] = row_3.data[2];
data[15] = row_3.data[3];
}
/// Static method to construct a 4-by-4 matrix from column vectors
CUTLASS_HOST_DEVICE
static Matrix from_columns(
Matrix<Element, 4, 1> const &column_0,
Matrix<Element, 4, 1> const &column_1,
Matrix<Element, 4, 1> const &column_2,
Matrix<Element, 4, 1> const &column_3
) {
Matrix result;
result.data[0] = column_0.data[0];
result.data[1] = column_1.data[0];
result.data[2] = column_2.data[0];
result.data[3] = column_3.data[0];
result.data[4] = column_0.data[1];
result.data[5] = column_1.data[1];
result.data[6] = column_2.data[1];
result.data[7] = column_3.data[1];
result.data[8] = column_0.data[2];
result.data[9] = column_1.data[2];
result.data[10] = column_2.data[2];
result.data[11] = column_3.data[2];
result.data[12] = column_0.data[3];
result.data[13] = column_1.data[3];
result.data[14] = column_2.data[3];
result.data[15] = column_3.data[3];
return result;
}
/// Constructs an identity matrix
CUTLASS_HOST_DEVICE
static Matrix identity() {
Matrix m;
m.data[0] = Element(1);
m.data[5] = Element(1);
m.data[10] = Element(1);
m.data[15] = Element(1);
return m;
}
/// Constructs a matrix from a uniform element
CUTLASS_HOST_DEVICE
static Matrix uniform(Element s) {
Matrix m;
m.data[0] = s;
m.data[1] = s;
m.data[2] = s;
m.data[3] = s;
m.data[4] = s;
m.data[5] = s;
m.data[6] = s;
m.data[7] = s;
m.data[8] = s;
m.data[9] = s;
m.data[10] = s;
m.data[11] = s;
m.data[12] = s;
m.data[13] = s;
m.data[14] = s;
m.data[15] = s;
return m;
}
/// Constructs a matrix from a uniform element 1
CUTLASS_HOST_DEVICE
static Matrix ones() {
return uniform(Element(1));
}
/// Constructs a matrix from a uniform element 0
CUTLASS_HOST_DEVICE
static Matrix zero() {
return Matrix();
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 4, 1> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[5] = diag.data[1];
m.data[10] = diag.data[2];
m.data[15] = diag.data[3];
return m;
}
/// Constructs a matrix from elements along its diagonal
CUTLASS_HOST_DEVICE
static Matrix from_diagonal(Matrix<Element, 1, 4> const &diag) {
Matrix m;
m.data[0] = diag.data[0];
m.data[5] = diag.data[1];
m.data[10] = diag.data[2];
m.data[15] = diag.data[3];
return m;
}
/// Gets an array of diagonal elements
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> diagonal() const {
Matrix<Element, 4, 1> diag;
diag.data[0] = data[0];
diag.data[1] = data[5];
diag.data[2] = data[10];
diag.data[3] = data[15];
return diag;
}
/// Returns a transposed matrix
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> transpose() const {
Matrix<Element, 4, 4> mt;
mt.data[0] = data[0];
mt.data[4] = data[1];
mt.data[8] = data[2];
mt.data[12] = data[3];
mt.data[1] = data[4];
mt.data[5] = data[5];
mt.data[9] = data[6];
mt.data[13] = data[7];
mt.data[2] = data[8];
mt.data[6] = data[9];
mt.data[10] = data[10];
mt.data[14] = data[11];
mt.data[3] = data[12];
mt.data[7] = data[13];
mt.data[11] = data[14];
mt.data[15] = data[15];
return mt;
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(int i, int j) const {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(int i, int j) {
return data[i * 4 + j];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element at(Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & at(Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element &at(int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element at(int offset) const {
return data[offset];
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element operator[](Coord<2> const &coord) const {
return at(coord[0], coord[1]);
}
/// Accesses an element by coordinate
CUTLASS_HOST_DEVICE
Element & operator[](Coord<2> const &coord) {
return at(coord[0], coord[1]);
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element & operator[](int offset) {
return data[offset];
}
/// Accesses an element by offset
CUTLASS_HOST_DEVICE
Element operator[](int offset) const {
return data[offset];
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 2> slice_1x2(int i = 0, int j = 0) const {
Matrix<Element, 1, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x2(Matrix<Element, 1, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 3> slice_1x3(int i = 0, int j = 0) const {
Matrix<Element, 1, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x3(Matrix<Element, 1, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> slice_1x4(int i = 0, int j = 0) const {
Matrix<Element, 1, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_1x4(Matrix<Element, 1, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 1, 4> row(int i) const {
return slice_1x4(i, 0);
}
CUTLASS_HOST_DEVICE
Matrix &set_row(Matrix<Element, 1, 4> const &v, int i = 0) {
return set_slice_1x4(v, i, 0);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 1> slice_2x1(int i = 0, int j = 0) const {
Matrix<Element, 2, 1> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 4];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x1(Matrix<Element, 2, 1> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 4] = m.data[1];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 2> slice_2x2(int i = 0, int j = 0) const {
Matrix<Element, 2, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 4];
m.data[3] = data[i * 4 + j + 5];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x2(Matrix<Element, 2, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 4] = m.data[2];
data[i * 4 + j + 5] = m.data[3];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 3> slice_2x3(int i = 0, int j = 0) const {
Matrix<Element, 2, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 4];
m.data[4] = data[i * 4 + j + 5];
m.data[5] = data[i * 4 + j + 6];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x3(Matrix<Element, 2, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 4] = m.data[3];
data[i * 4 + j + 5] = m.data[4];
data[i * 4 + j + 6] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 2, 4> slice_2x4(int i = 0, int j = 0) const {
Matrix<Element, 2, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
m.data[4] = data[i * 4 + j + 4];
m.data[5] = data[i * 4 + j + 5];
m.data[6] = data[i * 4 + j + 6];
m.data[7] = data[i * 4 + j + 7];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_2x4(Matrix<Element, 2, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
data[i * 4 + j + 4] = m.data[4];
data[i * 4 + j + 5] = m.data[5];
data[i * 4 + j + 6] = m.data[6];
data[i * 4 + j + 7] = m.data[7];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 1> slice_3x1(int i = 0, int j = 0) const {
Matrix<Element, 3, 1> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 4];
m.data[2] = data[i * 4 + j + 8];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x1(Matrix<Element, 3, 1> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 4] = m.data[1];
data[i * 4 + j + 8] = m.data[2];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 2> slice_3x2(int i = 0, int j = 0) const {
Matrix<Element, 3, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 4];
m.data[3] = data[i * 4 + j + 5];
m.data[4] = data[i * 4 + j + 8];
m.data[5] = data[i * 4 + j + 9];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x2(Matrix<Element, 3, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 4] = m.data[2];
data[i * 4 + j + 5] = m.data[3];
data[i * 4 + j + 8] = m.data[4];
data[i * 4 + j + 9] = m.data[5];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 3> slice_3x3(int i = 0, int j = 0) const {
Matrix<Element, 3, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 4];
m.data[4] = data[i * 4 + j + 5];
m.data[5] = data[i * 4 + j + 6];
m.data[6] = data[i * 4 + j + 8];
m.data[7] = data[i * 4 + j + 9];
m.data[8] = data[i * 4 + j + 10];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x3(Matrix<Element, 3, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 4] = m.data[3];
data[i * 4 + j + 5] = m.data[4];
data[i * 4 + j + 6] = m.data[5];
data[i * 4 + j + 8] = m.data[6];
data[i * 4 + j + 9] = m.data[7];
data[i * 4 + j + 10] = m.data[8];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 3, 4> slice_3x4(int i = 0, int j = 0) const {
Matrix<Element, 3, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
m.data[4] = data[i * 4 + j + 4];
m.data[5] = data[i * 4 + j + 5];
m.data[6] = data[i * 4 + j + 6];
m.data[7] = data[i * 4 + j + 7];
m.data[8] = data[i * 4 + j + 8];
m.data[9] = data[i * 4 + j + 9];
m.data[10] = data[i * 4 + j + 10];
m.data[11] = data[i * 4 + j + 11];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_3x4(Matrix<Element, 3, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
data[i * 4 + j + 4] = m.data[4];
data[i * 4 + j + 5] = m.data[5];
data[i * 4 + j + 6] = m.data[6];
data[i * 4 + j + 7] = m.data[7];
data[i * 4 + j + 8] = m.data[8];
data[i * 4 + j + 9] = m.data[9];
data[i * 4 + j + 10] = m.data[10];
data[i * 4 + j + 11] = m.data[11];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> slice_4x1(int i = 0, int j = 0) const {
Matrix<Element, 4, 1> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 4];
m.data[2] = data[i * 4 + j + 8];
m.data[3] = data[i * 4 + j + 12];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x1(Matrix<Element, 4, 1> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 4] = m.data[1];
data[i * 4 + j + 8] = m.data[2];
data[i * 4 + j + 12] = m.data[3];
return *this;
}
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> column(int j) const {
return slice_4x1(0, j);
}
CUTLASS_HOST_DEVICE
Matrix &set_column(Matrix<Element, 4, 1> const &v, int j =0) {
return set_slice_4x1(v, 0, j);
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> slice_4x2(int i = 0, int j = 0) const {
Matrix<Element, 4, 2> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 4];
m.data[3] = data[i * 4 + j + 5];
m.data[4] = data[i * 4 + j + 8];
m.data[5] = data[i * 4 + j + 9];
m.data[6] = data[i * 4 + j + 12];
m.data[7] = data[i * 4 + j + 13];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x2(Matrix<Element, 4, 2> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 4] = m.data[2];
data[i * 4 + j + 5] = m.data[3];
data[i * 4 + j + 8] = m.data[4];
data[i * 4 + j + 9] = m.data[5];
data[i * 4 + j + 12] = m.data[6];
data[i * 4 + j + 13] = m.data[7];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> slice_4x3(int i = 0, int j = 0) const {
Matrix<Element, 4, 3> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 4];
m.data[4] = data[i * 4 + j + 5];
m.data[5] = data[i * 4 + j + 6];
m.data[6] = data[i * 4 + j + 8];
m.data[7] = data[i * 4 + j + 9];
m.data[8] = data[i * 4 + j + 10];
m.data[9] = data[i * 4 + j + 12];
m.data[10] = data[i * 4 + j + 13];
m.data[11] = data[i * 4 + j + 14];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x3(Matrix<Element, 4, 3> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 4] = m.data[3];
data[i * 4 + j + 5] = m.data[4];
data[i * 4 + j + 6] = m.data[5];
data[i * 4 + j + 8] = m.data[6];
data[i * 4 + j + 9] = m.data[7];
data[i * 4 + j + 10] = m.data[8];
data[i * 4 + j + 12] = m.data[9];
data[i * 4 + j + 13] = m.data[10];
data[i * 4 + j + 14] = m.data[11];
return *this;
}
/// Gets a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> slice_4x4(int i = 0, int j = 0) const {
Matrix<Element, 4, 4> m;
m.data[0] = data[i * 4 + j + 0];
m.data[1] = data[i * 4 + j + 1];
m.data[2] = data[i * 4 + j + 2];
m.data[3] = data[i * 4 + j + 3];
m.data[4] = data[i * 4 + j + 4];
m.data[5] = data[i * 4 + j + 5];
m.data[6] = data[i * 4 + j + 6];
m.data[7] = data[i * 4 + j + 7];
m.data[8] = data[i * 4 + j + 8];
m.data[9] = data[i * 4 + j + 9];
m.data[10] = data[i * 4 + j + 10];
m.data[11] = data[i * 4 + j + 11];
m.data[12] = data[i * 4 + j + 12];
m.data[13] = data[i * 4 + j + 13];
m.data[14] = data[i * 4 + j + 14];
m.data[15] = data[i * 4 + j + 15];
return m;
}
/// Overwrites a submatrix with optional offset
CUTLASS_HOST_DEVICE
Matrix & set_slice_4x4(Matrix<Element, 4, 4> const &m, int i = 0, int j = 0) {
data[i * 4 + j + 0] = m.data[0];
data[i * 4 + j + 1] = m.data[1];
data[i * 4 + j + 2] = m.data[2];
data[i * 4 + j + 3] = m.data[3];
data[i * 4 + j + 4] = m.data[4];
data[i * 4 + j + 5] = m.data[5];
data[i * 4 + j + 6] = m.data[6];
data[i * 4 + j + 7] = m.data[7];
data[i * 4 + j + 8] = m.data[8];
data[i * 4 + j + 9] = m.data[9];
data[i * 4 + j + 10] = m.data[10];
data[i * 4 + j + 11] = m.data[11];
data[i * 4 + j + 12] = m.data[12];
data[i * 4 + j + 13] = m.data[13];
data[i * 4 + j + 14] = m.data[14];
data[i * 4 + j + 15] = m.data[15];
return *this;
}
/// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-3 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 4, 1> const & lhs, Matrix<Element, 4, 3> const & rhs) {
return Matrix(
lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2)
, lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2)
, lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2)
, lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1), rhs.at(3, 2));
}
/// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-2 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 4, 2> const & lhs, Matrix<Element, 4, 2> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1)
, lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1)
, lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1)
, lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0), rhs.at(3, 1));
}
/// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-3 matrix with a 4-by-1 matrix
CUTLASS_HOST_DEVICE
static Matrix hcat(Matrix<Element, 4, 3> const & lhs, Matrix<Element, 4, 1> const & rhs) {
return Matrix(
lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0)
, lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0)
, lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0)
, lhs.at(3, 0), lhs.at(3, 1), lhs.at(3, 2), rhs.at(3, 0));
}
/// Forms a 4-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 3-by-4 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 1, 4> const & upper, Matrix<Element, 3, 4> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)
, lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)
, lower.at(2, 0), lower.at(2, 1), lower.at(2, 2), lower.at(2, 3));
}
/// Forms a 4-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 2-by-4 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 2, 4> const & upper, Matrix<Element, 2, 4> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3)
, upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)
, lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3));
}
/// Forms a 4-by-4 matrix by vertically concatenating a 3-by-4 matrix with a 1-by-4 matrix
CUTLASS_HOST_DEVICE
static Matrix vcat(Matrix<Element, 3, 4> const & upper, Matrix<Element, 1, 4> const & lower) {
return Matrix(
upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3)
, upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3)
, upper.at(2, 0), upper.at(2, 1), upper.at(2, 2), upper.at(2, 3)
, lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3));
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Element A, Matrix<Element, 1, 3> const & B,
Matrix<Element, 3, 1> const & C, Matrix<Element, 3, 3> const & D) {
return Matrix(
A, B.at(0, 0), B.at(0, 1), B.at(0, 2)
, C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2)
, C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2)
, C.at(2, 0), D.at(2, 0), D.at(2, 1), D.at(2, 2)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 2> const & A, Matrix<Element, 1, 2> const & B,
Matrix<Element, 3, 2> const & C, Matrix<Element, 3, 2> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1)
, C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1)
, C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1)
, C.at(2, 0), C.at(2, 1), D.at(2, 0), D.at(2, 1)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 1, 3> const & A, Element B,
Matrix<Element, 3, 3> const & C, Matrix<Element, 3, 1> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), A.at(0, 2), B
, C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0)
, C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0)
, C.at(2, 0), C.at(2, 1), C.at(2, 2), D.at(2, 0)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 1> const & A, Matrix<Element, 2, 3> const & B,
Matrix<Element, 2, 1> const & C, Matrix<Element, 2, 3> const & D) {
return Matrix(
A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2)
, A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2)
, C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2)
, C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 2> const & A, Matrix<Element, 2, 2> const & B,
Matrix<Element, 2, 2> const & C, Matrix<Element, 2, 2> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1)
, A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1)
, C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1)
, C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 2, 3> const & A, Matrix<Element, 2, 1> const & B,
Matrix<Element, 2, 3> const & C, Matrix<Element, 2, 1> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0)
, A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0)
, C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0)
, C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 3, 1> const & A, Matrix<Element, 3, 3> const & B,
Element C, Matrix<Element, 1, 3> const & D) {
return Matrix(
A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2)
, A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2)
, A.at(2, 0), B.at(2, 0), B.at(2, 1), B.at(2, 2)
, C, D.at(0, 0), D.at(0, 1), D.at(0, 2)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 3, 2> const & A, Matrix<Element, 3, 2> const & B,
Matrix<Element, 1, 2> const & C, Matrix<Element, 1, 2> const & D) {
return Matrix(
A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1)
, A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1)
, A.at(2, 0), A.at(2, 1), B.at(2, 0), B.at(2, 1)
, C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1)
);
}
/// Forms a 4-by-4 matrix by concatenating four components
CUTLASS_HOST_DEVICE
static Matrix block(
Matrix<Element, 3, 3> const & A, Matrix<Element, 3, 1> const & B,
Matrix<Element, 1, 3> const & C, Element D) {
return Matrix(
A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0)
, A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0)
, A.at(2, 0), A.at(2, 1), A.at(2, 2), B.at(2, 0)
, C.at(0, 0), C.at(0, 1), C.at(0, 2), D
);
}
/// Elementwise add operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix add(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] + rhs.data[0];
result.data[1] = data[1] + rhs.data[1];
result.data[2] = data[2] + rhs.data[2];
result.data[3] = data[3] + rhs.data[3];
result.data[4] = data[4] + rhs.data[4];
result.data[5] = data[5] + rhs.data[5];
result.data[6] = data[6] + rhs.data[6];
result.data[7] = data[7] + rhs.data[7];
result.data[8] = data[8] + rhs.data[8];
result.data[9] = data[9] + rhs.data[9];
result.data[10] = data[10] + rhs.data[10];
result.data[11] = data[11] + rhs.data[11];
result.data[12] = data[12] + rhs.data[12];
result.data[13] = data[13] + rhs.data[13];
result.data[14] = data[14] + rhs.data[14];
result.data[15] = data[15] + rhs.data[15];
return result;
}
/// Elementwise add operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix operator +(Matrix const &rhs) const {
return add(rhs);
}
/// Elementwise add operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator +=(Matrix const &rhs) {
data[0] += rhs.data[0];
data[1] += rhs.data[1];
data[2] += rhs.data[2];
data[3] += rhs.data[3];
data[4] += rhs.data[4];
data[5] += rhs.data[5];
data[6] += rhs.data[6];
data[7] += rhs.data[7];
data[8] += rhs.data[8];
data[9] += rhs.data[9];
data[10] += rhs.data[10];
data[11] += rhs.data[11];
data[12] += rhs.data[12];
data[13] += rhs.data[13];
data[14] += rhs.data[14];
data[15] += rhs.data[15];
return *this;
}
/// Elementwise subtract operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix subtract(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] - rhs.data[0];
result.data[1] = data[1] - rhs.data[1];
result.data[2] = data[2] - rhs.data[2];
result.data[3] = data[3] - rhs.data[3];
result.data[4] = data[4] - rhs.data[4];
result.data[5] = data[5] - rhs.data[5];
result.data[6] = data[6] - rhs.data[6];
result.data[7] = data[7] - rhs.data[7];
result.data[8] = data[8] - rhs.data[8];
result.data[9] = data[9] - rhs.data[9];
result.data[10] = data[10] - rhs.data[10];
result.data[11] = data[11] - rhs.data[11];
result.data[12] = data[12] - rhs.data[12];
result.data[13] = data[13] - rhs.data[13];
result.data[14] = data[14] - rhs.data[14];
result.data[15] = data[15] - rhs.data[15];
return result;
}
/// Elementwise subtract operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix operator -(Matrix const &rhs) const {
return subtract(rhs);
}
/// Elementwise subtract operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator -=(Matrix const &rhs) {
data[0] -= rhs.data[0];
data[1] -= rhs.data[1];
data[2] -= rhs.data[2];
data[3] -= rhs.data[3];
data[4] -= rhs.data[4];
data[5] -= rhs.data[5];
data[6] -= rhs.data[6];
data[7] -= rhs.data[7];
data[8] -= rhs.data[8];
data[9] -= rhs.data[9];
data[10] -= rhs.data[10];
data[11] -= rhs.data[11];
data[12] -= rhs.data[12];
data[13] -= rhs.data[13];
data[14] -= rhs.data[14];
data[15] -= rhs.data[15];
return *this;
}
/// Elementwise multiply operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] * rhs.data[0];
result.data[1] = data[1] * rhs.data[1];
result.data[2] = data[2] * rhs.data[2];
result.data[3] = data[3] * rhs.data[3];
result.data[4] = data[4] * rhs.data[4];
result.data[5] = data[5] * rhs.data[5];
result.data[6] = data[6] * rhs.data[6];
result.data[7] = data[7] * rhs.data[7];
result.data[8] = data[8] * rhs.data[8];
result.data[9] = data[9] * rhs.data[9];
result.data[10] = data[10] * rhs.data[10];
result.data[11] = data[11] * rhs.data[11];
result.data[12] = data[12] * rhs.data[12];
result.data[13] = data[13] * rhs.data[13];
result.data[14] = data[14] * rhs.data[14];
result.data[15] = data[15] * rhs.data[15];
return result;
}
/// Scalar multiply operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix multiply(Element const &s) const {
Matrix result;
result.data[0] = data[0] * s;
result.data[1] = data[1] * s;
result.data[2] = data[2] * s;
result.data[3] = data[3] * s;
result.data[4] = data[4] * s;
result.data[5] = data[5] * s;
result.data[6] = data[6] * s;
result.data[7] = data[7] * s;
result.data[8] = data[8] * s;
result.data[9] = data[9] * s;
result.data[10] = data[10] * s;
result.data[11] = data[11] * s;
result.data[12] = data[12] * s;
result.data[13] = data[13] * s;
result.data[14] = data[14] * s;
result.data[15] = data[15] * s;
return result;
}
/// Scalar multiply operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix operator *(Element const &s) const {
return multiply(s);
}
/// Scalar multiply operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator *=(Element const &s) {
data[0] *= s;
data[1] *= s;
data[2] *= s;
data[3] *= s;
data[4] *= s;
data[5] *= s;
data[6] *= s;
data[7] *= s;
data[8] *= s;
data[9] *= s;
data[10] *= s;
data[11] *= s;
data[12] *= s;
data[13] *= s;
data[14] *= s;
data[15] *= s;
return *this;
}
/// Elementwise divide operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Matrix const &rhs) const {
Matrix result;
result.data[0] = data[0] / rhs.data[0];
result.data[1] = data[1] / rhs.data[1];
result.data[2] = data[2] / rhs.data[2];
result.data[3] = data[3] / rhs.data[3];
result.data[4] = data[4] / rhs.data[4];
result.data[5] = data[5] / rhs.data[5];
result.data[6] = data[6] / rhs.data[6];
result.data[7] = data[7] / rhs.data[7];
result.data[8] = data[8] / rhs.data[8];
result.data[9] = data[9] / rhs.data[9];
result.data[10] = data[10] / rhs.data[10];
result.data[11] = data[11] / rhs.data[11];
result.data[12] = data[12] / rhs.data[12];
result.data[13] = data[13] / rhs.data[13];
result.data[14] = data[14] / rhs.data[14];
result.data[15] = data[15] / rhs.data[15];
return result;
}
/// Scalar divide operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix divide(Element const &s) const {
Matrix result;
result.data[0] = data[0] / s;
result.data[1] = data[1] / s;
result.data[2] = data[2] / s;
result.data[3] = data[3] / s;
result.data[4] = data[4] / s;
result.data[5] = data[5] / s;
result.data[6] = data[6] / s;
result.data[7] = data[7] / s;
result.data[8] = data[8] / s;
result.data[9] = data[9] / s;
result.data[10] = data[10] / s;
result.data[11] = data[11] / s;
result.data[12] = data[12] / s;
result.data[13] = data[13] / s;
result.data[14] = data[14] / s;
result.data[15] = data[15] / s;
return result;
}
/// Scalar divide operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Element const &s) const {
return divide(s);
}
/// Scalar divide operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Element const &s) {
data[0] /= s;
data[1] /= s;
data[2] /= s;
data[3] /= s;
data[4] /= s;
data[5] /= s;
data[6] /= s;
data[7] /= s;
data[8] /= s;
data[9] /= s;
data[10] /= s;
data[11] /= s;
data[12] /= s;
data[13] /= s;
data[14] /= s;
data[15] /= s;
return *this;
}
/// Elementwise divide operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix operator /(Matrix const &rhs) const {
return divide(rhs);
}
/// Elementwise divide operator (4-by-4)
CUTLASS_HOST_DEVICE
Matrix & operator /=(Matrix const &rhs) {
data[0] /= rhs.data[0];
data[1] /= rhs.data[1];
data[2] /= rhs.data[2];
data[3] /= rhs.data[3];
data[4] /= rhs.data[4];
data[5] /= rhs.data[5];
data[6] /= rhs.data[6];
data[7] /= rhs.data[7];
data[8] /= rhs.data[8];
data[9] /= rhs.data[9];
data[10] /= rhs.data[10];
data[11] /= rhs.data[11];
data[12] /= rhs.data[12];
data[13] /= rhs.data[13];
data[14] /= rhs.data[14];
data[15] /= rhs.data[15];
return *this;
}
/// Negates each element of the matrix
CUTLASS_HOST_DEVICE
Matrix operator-() const {
Matrix m;
m.data[0] = -m.data[0];
m.data[1] = -m.data[1];
m.data[2] = -m.data[2];
m.data[3] = -m.data[3];
m.data[4] = -m.data[4];
m.data[5] = -m.data[5];
m.data[6] = -m.data[6];
m.data[7] = -m.data[7];
m.data[8] = -m.data[8];
m.data[9] = -m.data[9];
m.data[10] = -m.data[10];
m.data[11] = -m.data[11];
m.data[12] = -m.data[12];
m.data[13] = -m.data[13];
m.data[14] = -m.data[14];
m.data[15] = -m.data[15];
return m;
}
/// Matrix product of size 4-by-1-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> product(
Matrix<Element, 4, 1> const &rhs,
Matrix<Element, 4, 1> accum = Matrix<Element, 4, 1>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[4] * rhs.data[0];
accum.data[2] += data[8] * rhs.data[0];
accum.data[3] += data[12] * rhs.data[0];
// k=1
accum.data[0] += data[1] * rhs.data[1];
accum.data[1] += data[5] * rhs.data[1];
accum.data[2] += data[9] * rhs.data[1];
accum.data[3] += data[13] * rhs.data[1];
// k=2
accum.data[0] += data[2] * rhs.data[2];
accum.data[1] += data[6] * rhs.data[2];
accum.data[2] += data[10] * rhs.data[2];
accum.data[3] += data[14] * rhs.data[2];
// k=3
accum.data[0] += data[3] * rhs.data[3];
accum.data[1] += data[7] * rhs.data[3];
accum.data[2] += data[11] * rhs.data[3];
accum.data[3] += data[15] * rhs.data[3];
return accum;
}
/// Matrix product of size 4-by-1-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 1> operator*(Matrix<Element, 4, 1> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> product(
Matrix<Element, 4, 2> const &rhs,
Matrix<Element, 4, 2> accum = Matrix<Element, 4, 2>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[4] * rhs.data[0];
accum.data[3] += data[4] * rhs.data[1];
accum.data[4] += data[8] * rhs.data[0];
accum.data[5] += data[8] * rhs.data[1];
accum.data[6] += data[12] * rhs.data[0];
accum.data[7] += data[12] * rhs.data[1];
// k=1
accum.data[0] += data[1] * rhs.data[2];
accum.data[1] += data[1] * rhs.data[3];
accum.data[2] += data[5] * rhs.data[2];
accum.data[3] += data[5] * rhs.data[3];
accum.data[4] += data[9] * rhs.data[2];
accum.data[5] += data[9] * rhs.data[3];
accum.data[6] += data[13] * rhs.data[2];
accum.data[7] += data[13] * rhs.data[3];
// k=2
accum.data[0] += data[2] * rhs.data[4];
accum.data[1] += data[2] * rhs.data[5];
accum.data[2] += data[6] * rhs.data[4];
accum.data[3] += data[6] * rhs.data[5];
accum.data[4] += data[10] * rhs.data[4];
accum.data[5] += data[10] * rhs.data[5];
accum.data[6] += data[14] * rhs.data[4];
accum.data[7] += data[14] * rhs.data[5];
// k=3
accum.data[0] += data[3] * rhs.data[6];
accum.data[1] += data[3] * rhs.data[7];
accum.data[2] += data[7] * rhs.data[6];
accum.data[3] += data[7] * rhs.data[7];
accum.data[4] += data[11] * rhs.data[6];
accum.data[5] += data[11] * rhs.data[7];
accum.data[6] += data[15] * rhs.data[6];
accum.data[7] += data[15] * rhs.data[7];
return accum;
}
/// Matrix product of size 4-by-2-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 2> operator*(Matrix<Element, 4, 2> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> product(
Matrix<Element, 4, 3> const &rhs,
Matrix<Element, 4, 3> accum = Matrix<Element, 4, 3>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[4] * rhs.data[0];
accum.data[4] += data[4] * rhs.data[1];
accum.data[5] += data[4] * rhs.data[2];
accum.data[6] += data[8] * rhs.data[0];
accum.data[7] += data[8] * rhs.data[1];
accum.data[8] += data[8] * rhs.data[2];
accum.data[9] += data[12] * rhs.data[0];
accum.data[10] += data[12] * rhs.data[1];
accum.data[11] += data[12] * rhs.data[2];
// k=1
accum.data[0] += data[1] * rhs.data[3];
accum.data[1] += data[1] * rhs.data[4];
accum.data[2] += data[1] * rhs.data[5];
accum.data[3] += data[5] * rhs.data[3];
accum.data[4] += data[5] * rhs.data[4];
accum.data[5] += data[5] * rhs.data[5];
accum.data[6] += data[9] * rhs.data[3];
accum.data[7] += data[9] * rhs.data[4];
accum.data[8] += data[9] * rhs.data[5];
accum.data[9] += data[13] * rhs.data[3];
accum.data[10] += data[13] * rhs.data[4];
accum.data[11] += data[13] * rhs.data[5];
// k=2
accum.data[0] += data[2] * rhs.data[6];
accum.data[1] += data[2] * rhs.data[7];
accum.data[2] += data[2] * rhs.data[8];
accum.data[3] += data[6] * rhs.data[6];
accum.data[4] += data[6] * rhs.data[7];
accum.data[5] += data[6] * rhs.data[8];
accum.data[6] += data[10] * rhs.data[6];
accum.data[7] += data[10] * rhs.data[7];
accum.data[8] += data[10] * rhs.data[8];
accum.data[9] += data[14] * rhs.data[6];
accum.data[10] += data[14] * rhs.data[7];
accum.data[11] += data[14] * rhs.data[8];
// k=3
accum.data[0] += data[3] * rhs.data[9];
accum.data[1] += data[3] * rhs.data[10];
accum.data[2] += data[3] * rhs.data[11];
accum.data[3] += data[7] * rhs.data[9];
accum.data[4] += data[7] * rhs.data[10];
accum.data[5] += data[7] * rhs.data[11];
accum.data[6] += data[11] * rhs.data[9];
accum.data[7] += data[11] * rhs.data[10];
accum.data[8] += data[11] * rhs.data[11];
accum.data[9] += data[15] * rhs.data[9];
accum.data[10] += data[15] * rhs.data[10];
accum.data[11] += data[15] * rhs.data[11];
return accum;
}
/// Matrix product of size 4-by-3-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 3> operator*(Matrix<Element, 4, 3> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> product(
Matrix<Element, 4, 4> const &rhs,
Matrix<Element, 4, 4> accum = Matrix<Element, 4, 4>()
) const {
// k=0
accum.data[0] += data[0] * rhs.data[0];
accum.data[1] += data[0] * rhs.data[1];
accum.data[2] += data[0] * rhs.data[2];
accum.data[3] += data[0] * rhs.data[3];
accum.data[4] += data[4] * rhs.data[0];
accum.data[5] += data[4] * rhs.data[1];
accum.data[6] += data[4] * rhs.data[2];
accum.data[7] += data[4] * rhs.data[3];
accum.data[8] += data[8] * rhs.data[0];
accum.data[9] += data[8] * rhs.data[1];
accum.data[10] += data[8] * rhs.data[2];
accum.data[11] += data[8] * rhs.data[3];
accum.data[12] += data[12] * rhs.data[0];
accum.data[13] += data[12] * rhs.data[1];
accum.data[14] += data[12] * rhs.data[2];
accum.data[15] += data[12] * rhs.data[3];
// k=1
accum.data[0] += data[1] * rhs.data[4];
accum.data[1] += data[1] * rhs.data[5];
accum.data[2] += data[1] * rhs.data[6];
accum.data[3] += data[1] * rhs.data[7];
accum.data[4] += data[5] * rhs.data[4];
accum.data[5] += data[5] * rhs.data[5];
accum.data[6] += data[5] * rhs.data[6];
accum.data[7] += data[5] * rhs.data[7];
accum.data[8] += data[9] * rhs.data[4];
accum.data[9] += data[9] * rhs.data[5];
accum.data[10] += data[9] * rhs.data[6];
accum.data[11] += data[9] * rhs.data[7];
accum.data[12] += data[13] * rhs.data[4];
accum.data[13] += data[13] * rhs.data[5];
accum.data[14] += data[13] * rhs.data[6];
accum.data[15] += data[13] * rhs.data[7];
// k=2
accum.data[0] += data[2] * rhs.data[8];
accum.data[1] += data[2] * rhs.data[9];
accum.data[2] += data[2] * rhs.data[10];
accum.data[3] += data[2] * rhs.data[11];
accum.data[4] += data[6] * rhs.data[8];
accum.data[5] += data[6] * rhs.data[9];
accum.data[6] += data[6] * rhs.data[10];
accum.data[7] += data[6] * rhs.data[11];
accum.data[8] += data[10] * rhs.data[8];
accum.data[9] += data[10] * rhs.data[9];
accum.data[10] += data[10] * rhs.data[10];
accum.data[11] += data[10] * rhs.data[11];
accum.data[12] += data[14] * rhs.data[8];
accum.data[13] += data[14] * rhs.data[9];
accum.data[14] += data[14] * rhs.data[10];
accum.data[15] += data[14] * rhs.data[11];
// k=3
accum.data[0] += data[3] * rhs.data[12];
accum.data[1] += data[3] * rhs.data[13];
accum.data[2] += data[3] * rhs.data[14];
accum.data[3] += data[3] * rhs.data[15];
accum.data[4] += data[7] * rhs.data[12];
accum.data[5] += data[7] * rhs.data[13];
accum.data[6] += data[7] * rhs.data[14];
accum.data[7] += data[7] * rhs.data[15];
accum.data[8] += data[11] * rhs.data[12];
accum.data[9] += data[11] * rhs.data[13];
accum.data[10] += data[11] * rhs.data[14];
accum.data[11] += data[11] * rhs.data[15];
accum.data[12] += data[15] * rhs.data[12];
accum.data[13] += data[15] * rhs.data[13];
accum.data[14] += data[15] * rhs.data[14];
accum.data[15] += data[15] * rhs.data[15];
return accum;
}
/// Matrix product of size 4-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix<Element, 4, 4> operator*(Matrix<Element, 4, 4> const &rhs) const {
return product(rhs);
}
/// Matrix product of size 4-by-4-by-4
CUTLASS_HOST_DEVICE
Matrix & operator*=(Matrix<Element, 4, 4> const &rhs) {
*this = product(rhs);
return *this;
}
/// Returns the sum of elements
CUTLASS_HOST_DEVICE
Element sum(Element accum = Element()) const {
accum += data[0];
accum += data[1];
accum += data[2];
accum += data[3];
accum += data[4];
accum += data[5];
accum += data[6];
accum += data[7];
accum += data[8];
accum += data[9];
accum += data[10];
accum += data[11];
accum += data[12];
accum += data[13];
accum += data[14];
accum += data[15];
return accum;
}
/// Returns the sum of squared elements
CUTLASS_HOST_DEVICE
Element norm(Element accum = Element()) const {
accum += data[0] * data[0];
accum += data[1] * data[1];
accum += data[2] * data[2];
accum += data[3] * data[3];
accum += data[4] * data[4];
accum += data[5] * data[5];
accum += data[6] * data[6];
accum += data[7] * data[7];
accum += data[8] * data[8];
accum += data[9] * data[9];
accum += data[10] * data[10];
accum += data[11] * data[11];
accum += data[12] * data[12];
accum += data[13] * data[13];
accum += data[14] * data[14];
accum += data[15] * data[15];
return accum;
}
/// Returns square root of the norm
CUTLASS_HOST_DEVICE
Element magnitude() const {
return fast_sqrt(norm());
}
/// Returns the sum of diagonal elements
CUTLASS_HOST_DEVICE
Element trace(Element accum = Element()) const {
accum += data[0];
accum += data[5];
accum += data[10];
accum += data[15];
return accum;
}
/// Returns 4-by-4 rotation matrix around the X axis
CUTLASS_HOST_DEVICE
static Matrix rotation_X(Element theta) {
Matrix m = identity();
Element c = fast_cos(theta);
Element s = fast_sin(theta);
m.at(1, 1) = c;
m.at(1, 2) = -s;
m.at(2, 1) = s;
m.at(2, 2) = c;
return m;
}
/// Returns 4-by-4 rotation matrix around the Y axis
CUTLASS_HOST_DEVICE
static Matrix rotation_Y(Element theta) {
Matrix m = identity();
Element c = fast_cos(theta);
Element s = fast_sin(theta);
m.at(0, 0) = c;
m.at(2, 0) = -s;
m.at(0, 2) = s;
m.at(2, 2) = c;
return m;
}
/// Returns 4-by-4 rotation matrix around the Z axis
CUTLASS_HOST_DEVICE
static Matrix rotation_Z(Element theta) {
Matrix m = Matrix::identity();
Element c = fast_cos(theta);
Element s = fast_sin(theta);
m.at(0, 0) = c;
m.at(0, 1) = -s;
m.at(1, 0) = s;
m.at(1, 1) = c;
return m;
}
/// Returns a 4-by-4 rotation matrix around a unit-length axis
CUTLASS_HOST_DEVICE
static Matrix rotation(Element theta, Matrix<Element, 3, 1> const &u) {
Element x = u.data[0];
Element y = u.data[1];
Element z = u.data[2];
Element c = fast_cos(theta);
Element s = fast_sin(theta);
Element one_minus_cos = Element(1) - fast_cos(theta);
Matrix m;
m.set_slice3x3({
c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s,
y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s,
z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos
});
return m;
}
/// Returns a 4-by-4 reflection about the plane specified by the
/// unit-length normal vector n_unit
CUTLASS_HOST_DEVICE
static Matrix reflection(Matrix<Element, 3, 1> const &n_unit) {
Element a = n_unit.data[0];
Element b = n_unit.data[1];
Element c = n_unit.data[2];
Matrix m = Matrix::identity();
m.set_slice3x3({
Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c,
Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c,
Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c
});
return m;
}
/// Returns a perspective projection matrix typical of OpenGL applications
CUTLASS_HOST_DEVICE
static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) {
Element aspect = fovH / fovV;
Element f = Element(cos(fovV)) / Element(fovH);
Element Q = near_plane - far_plane;
return Matrix(
f / aspect, 0, 0, 0,
0, f, 0, 0,
0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q,
0, 0, -1, 0
);
}
CUTLASS_HOST_DEVICE
static Matrix translation(Matrix<Element, 3, 1> const &v) {
return Matrix(
1, 0, 0, v.data[0],
0, 1, 0, v.data[1],
0, 0, 1, v.data[2],
0, 0, 0, 1
);
}
/// Computes the determinant of a 4-by-4 matrix
CUTLASS_HOST_DEVICE
Element determinant(Element accum = Element()) const {
accum += at(0, 0) * Matrix<Element, 3, 3>({ at(1, 1), at(1, 2), at(1, 3), at(2, 1), at(2, 2), at(2, 3), at(3, 1), at(3, 2), at(3, 3) }).determinant();
accum -= at(0, 1) * Matrix<Element, 3, 3>({ at(1, 0), at(1, 2), at(1, 3), at(2, 0), at(2, 2), at(2, 3), at(3, 0), at(3, 2), at(3, 3) }).determinant();
accum += at(0, 2) * Matrix<Element, 3, 3>({ at(1, 0), at(1, 1), at(1, 3), at(2, 0), at(2, 1), at(2, 3), at(3, 0), at(3, 1), at(3, 3) }).determinant();
accum -= at(0, 3) * Matrix<Element, 3, 3>({ at(1, 0), at(1, 1), at(1, 2), at(2, 0), at(2, 1), at(2, 2), at(3, 0), at(3, 1), at(3, 2) }).determinant();
return accum;
}
/// Computes the inverse of a 4-by-4 matrix (ignores the optional argument)
CUTLASS_HOST_DEVICE
Matrix inverse(Element ignore = 1) const {
Matrix<Element, 2, 2> B = slice_2x2(0, 2);
Matrix<Element, 2, 2> A = slice_2x2(0, 0);
Matrix<Element, 2, 2> C = slice_2x2(2, 0);
Matrix<Element, 2, 2> D = slice_2x2(2, 2);
Matrix<Element, 2, 2> D_inv = D.inverse();
Matrix<Element, 2, 2> E = (A - B * D_inv * C).inverse();
return Matrix::block(
E, -E * B * D_inv,
-D_inv * C * E, D_inv + D_inv * C * E * B * D_inv
);
}
};
/// Template alias for 4-by-4 matrix
template <typename Element>
using Matrix4x4 = Matrix<Element, 4, 4>;
/// Free funciton to infer element type from template arguments
template <typename Element>
CUTLASS_HOST_DEVICE Matrix4x4<Element> make_Matrix4x4(
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
Element _1_0, Element _1_1, Element _1_2, Element _1_3,
Element _2_0, Element _2_1, Element _2_2, Element _2_3,
Element _3_0, Element _3_1, Element _3_2, Element _3_3
) {
return Matrix4x4<Element>(
_0_0, _0_1, _0_2, _0_3,
_1_0, _1_1, _1_2, _1_3,
_2_0, _2_1, _2_2, _2_3,
_3_0, _3_1, _3_2, _3_3
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Elementwise scalar multiplication
template <typename Element, int Rows, int Columns>
CUTLASS_HOST_DEVICE
Matrix<Element, Rows, Columns> operator*(Element s, Matrix<Element, Rows, Columns> const &rhs) {
return rhs.multiply(s);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////