v4.1 release

This commit is contained in:
Junkai-Wu
2025-07-03 20:07:53 +08:00
committed by GitHub
parent b995f93317
commit a1aaf2300a
155 changed files with 18407 additions and 6068 deletions

View File

@ -260,13 +260,10 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
{
// If more than one element vectorizes to 8bits or more, then recast and copy
using VecType = uint_bit_t<vec_bits>;
// Preserve volatility
using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
// Recast
Tensor src_v = recast<SrcVecType>(src);
Tensor dst_v = recast<DstVecType>(dst);
Tensor src_v = recast<VecType>(src);
Tensor dst_v = recast<VecType>(dst);
return copy_if(constant_fn<true_type>{}, src_v, dst_v);
} else {
return copy_if(constant_fn<true_type>{}, src, dst);

View File

@ -325,21 +325,6 @@ struct TiledCopy : Copy_Atom
return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{});
}
CUTE_HOST_DEVICE constexpr static
auto
get_layoutS_MN()
{
// (thr_idx,val_idx) -> (M,N)
auto layoutS_TV = get_layoutS_TV();
// (M,K) -> (thr_idx,val_idx)
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{}));
// athrid = (v,m,k) -> thr_idx
auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));
return cute::make_tuple(layoutS_MK, thrID_S);
}
CUTE_HOST_DEVICE constexpr static
auto
get_layoutD_TV()
@ -350,21 +335,6 @@ struct TiledCopy : Copy_Atom
return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{});
}
CUTE_HOST_DEVICE constexpr static
auto
get_layoutD_MN()
{
// (thr_idx,val_idx) -> (M,N)
auto layoutD_TV = get_layoutD_TV();
// (M,K) -> (thr_idx,val_idx)
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{}));
// athrid = (v,m,k) -> thr_idx
auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));
return cute::make_tuple(layoutD_MK, thrID_D);
}
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE static
@ -680,101 +650,6 @@ print(ThrCopy<TiledCopy, ThrIdx> const& thr_copy)
print(TiledCopy{});
}
// TiledCopy to LaTeX TikZ
template <class... Args, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
auto
print_latex(TiledCopy<Args...> const& copy,
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN();
auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN();
print_latex_copy(layoutS_MN, thrID_S,
layoutD_MN, thrID_D);
}
// MNK Copy Layout to LaTeX TikZ
template <class LayoutS, class ThrIDS,
class LayoutD, class ThrIDD,
class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx
LayoutD const& D, ThrIDD const& TD, // (m,n) -> (tid,vid) and tid -> thr_idx
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});
assert(size<0>(S) == size<0>(D));
assert(size<1>(S) == size<1>(D));
// Commented prints
printf("%% LayoutS: "); print(S); printf("\n");
printf("%% ThrIDS : "); print(TS); printf("\n");
printf("%% LayoutD: "); print(D); printf("\n");
printf("%% ThrIDD : "); print(TD); printf("\n\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
// S starting at 0,0
for (int i = 0; i < size<0>(S); ++i) {
for (int j = 0; j < size<1>(S); ++j) {
int thrid = S(i,j) % size(TS);
int val_idx = S(i,j) / size(TS);
int thr_idx = TS(thrid);
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(thr_idx, val_idx),
i, j,
thr_idx, val_idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, 0, int(size<0>(S)), int(size<1>(S)));
// S Labels
for (int i = 0, j = -1; i < size<0>(S); ++i) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
}
for (int i = -1, j = 0; j < size<1>(S); ++j) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
}
// D starting at 0,size<1>(S)+3
for (int i = 0; i < size<0>(D); ++i) {
for (int j = 0; j < size<1>(D); ++j) {
int thrid = D(i,j) % size(TD);
int val_idx = D(i,j) / size(TD);
int thr_idx = TD(thrid);
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(thr_idx, val_idx),
i, j + size<1>(S) + 3,
thr_idx, val_idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, int(size<1>(S)+3), int(size<0>(D)), int(size<1>(D)+size<1>(S)+3));
// D Labels
for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i);
}
for (int i = -1, j = 0; j < size<1>(D); ++j) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
} // end namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -180,10 +180,10 @@ struct MMA_Atom<MMA_Traits<MMAOperation, Args...>>
if constexpr (has_dereference<FrgTypeB>::value) {
// If the intended FrgTypeB is a view (of the current tensor), forward the whole
static_assert(is_same<ValTypeB, typename remove_cvref_t<BTensor>::value_type>::value
|| (sizeof_bits_v<typename remove_cvref_t<BTensor>::value_type> == 8 &&
(sizeof_bits_v<ValTypeB> == 8 || sizeof_bits_v<ValTypeB> == 6 || sizeof_bits_v<ValTypeB> == 4))
, "Expecting ValTypeB type");
return make_tensor<FrgTypeB>(static_cast<BTensor&&>(btensor));
} else {
@ -394,55 +394,22 @@ struct TiledMMA : MMA_Atom
return size(permutation_mnk<I>());
}
CUTE_HOST_DEVICE constexpr
auto
get_layoutC_MN() const
{
// (M,N) -> (M,N)
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
// (cthrid,val) -> (M,N)
auto layoutC_TV = thrfrg_C(ref_C);
// (M,N) -> (cthrid,frg)
auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C));
// cthrid = (v,m,n) -> thr_idx
auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{});
return cute::make_tuple(layoutC_MN, thrID_C);
}
CUTE_HOST_DEVICE constexpr
auto
get_layoutC_TV() const
{
// (M,N) -> (M,N)
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
// (cthrid,val) -> (M,N)
auto layoutC_TV = thrfrg_C(ref_C);
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}),
make_stride(Int<1>{}, Int<0>{})),
right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_))));
// (thr_idx,val) -> (M,N)
return layoutC_TV.compose(thridx_2_thrid, _);
return thrfrg_C(ref_C).compose(thridx_2_thrid, _);
}
CUTE_HOST_DEVICE constexpr
auto
get_layoutA_MK() const
{
// (M,K) -> (M,K)
auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>()));
// (athrid,val) -> (M,K)
auto layoutA_TV = thrfrg_A(ref_A);
// (M,K) -> (athrid,frg)
auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A));
// athrid = (v,m,k) -> thr_idx
auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_);
return cute::make_tuple(layoutA_MK, thrID_A);
}
CUTE_HOST_DEVICE constexpr
auto
@ -458,29 +425,14 @@ struct TiledMMA : MMA_Atom
_));
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}),
make_stride(Int<1>{}, Int<0>{})),
right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_))));
// (thr_idx,val) -> (M,K)
return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _);
}
CUTE_HOST_DEVICE constexpr
auto
get_layoutB_NK() const
{
// (N,K) -> (N,K)
auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>()));
// (bthrid,val) -> (N,K)
auto layoutB_TV = thrfrg_B(ref_B);
// (N,K) -> (bthrid,frg)
auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B));
// bthrid = (v,n,k) -> thr_idx
auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_);
return cute::make_tuple(layoutB_NK, thrID_B);
}
CUTE_HOST_DEVICE constexpr
auto
get_layoutB_TV() const
@ -495,7 +447,9 @@ struct TiledMMA : MMA_Atom
_));
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}),
make_stride(Int<1>{}, Int<0>{})),
right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_))));
// (thr_idx,val) -> (N,K)
return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _);
@ -733,376 +687,6 @@ print(ThrMMA<TiledMMA, ThrVMNK> const& thr_mma)
print(static_cast<TiledMMA>(thr_mma));
}
// MMA Atom to LaTeX TikZ
template <class... Args, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex(MMA_Atom<Args...> const& mma_atom,
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
print_latex(make_tiled_mma(mma_atom));
}
// TiledMMA to LaTeX TikZ
template <class... Args, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex(TiledMMA<Args...> const& mma,
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
auto layout_and_thrid_C = mma.get_layoutC_MN();
auto layoutC_MN = get<0>(layout_and_thrid_C);
auto thrID_C = get<1>(layout_and_thrid_C);
auto layout_and_thrid_A = mma.get_layoutA_MK();
auto layoutA_MK = get<0>(layout_and_thrid_A);
auto thrID_A = get<1>(layout_and_thrid_A);
auto layout_and_thrid_B = mma.get_layoutB_NK();
auto layoutB_NK = get<0>(layout_and_thrid_B);
auto thrID_B = get<1>(layout_and_thrid_B);
print_latex_mma(layoutC_MN, thrID_C,
layoutA_MK, thrID_A,
layoutB_NK, thrID_B);
}
// MNK MMA Layout to LaTeX TikZ
template <class LayoutC, class ThrIDC,
class LayoutA, class ThrIDA,
class LayoutB, class ThrIDB,
class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
LayoutB const& B, ThrIDB const& TB, // (n,k) -> (tid,vid) and tid -> thr_idx
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
assert(size<0>(A) == size<0>(C));
assert(size<0>(B) == size<1>(C));
assert(size<1>(A) == size<1>(B));
// Commented prints
printf("%% LayoutC: "); print(C); printf("\n");
printf("%% ThrIDC : "); print(TC); printf("\n");
printf("%% LayoutA: "); print(A); printf("\n");
printf("%% ThrIDA : "); print(TA); printf("\n");
printf("%% LayoutB: "); print(B); printf("\n");
printf("%% ThrIDB : "); print(TB); printf("\n\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
// C starting at 0,0
for (int m = 0; m < size<0>(C); ++m) {
for (int n = 0; n < size<1>(C); ++n) {
int thrid = C(m,n) % size(TC);
int val_idx = C(m,n) / size(TC);
int thr_idx = TC(thrid);
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(thr_idx, val_idx),
m, n,
thr_idx, val_idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, 0, int(size<0>(C)), int(size<1>(C)));
// A starting at 0,-size<1>(A)-1
for (int m = 0; m < size<0>(A); ++m) {
for (int k = 0; k < size<1>(A); ++k) {
int thrid = A(m,k) % size(TA);
int val_idx = A(m,k) / size(TA);
int thr_idx = TA(thrid);
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(thr_idx, val_idx),
m, k-1-size<1>(A),
thr_idx, val_idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, int(-size<1>(A)-1), int(size<0>(A)), -1);
// A labels
for (int m = 0, k = -1; m < size<0>(A); ++m) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m);
}
for (int m = -1, k = 0; k < size<1>(A); ++k) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k);
}
// B starting at -size<1>(B)-1,0
for (int n = 0; n < size<0>(B); ++n) {
for (int k = 0; k < size<1>(B); ++k) {
int thrid = B(n,k) % size(TB);
int val_idx = B(n,k) / size(TB);
int thr_idx = TB(thrid);
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(thr_idx, val_idx),
k-1-size<1>(B), n,
thr_idx, val_idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
int(-size<1>(B)-1), 0, -1, int(size<0>(B)));
// B labels
for (int n = 0, k = -1; n < size<0>(B); ++n) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n);
}
for (int n = -1, k = 0; k < size<1>(B); ++k) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
// MNK MMA Layout to console printer
template <class LayoutC, class ThrIDC,
class LayoutA, class ThrIDA,
class LayoutB, class ThrIDB>
CUTE_HOST_DEVICE
void
print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx
{
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
assert(size<0>(A) == size<0>(C));
assert(size<0>(B) == size<1>(C));
assert(size<1>(A) == size<1>(B));
int a_width = size<1>(A) * 6 + 4;
// Print out B (white-shifted) k-by-n
for (int k = 0; k < size<1>(B); ++k) {
// Header
printf("%*s", a_width, "");
for (int n = 0; n < size<0>(B); ++n) printf("+-----");
printf("+\n");
// Values
printf("%*s", a_width, "");
for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB)));
printf("|\n");
}
// Footer
printf("%*s", a_width, "");
for (int n = 0; n < size<0>(B); ++n) printf("+-----");
printf("+\n\n");
// Print out A m-by-k and C m-by-n
for (int m = 0; m < size<0>(A); ++m) {
// Header
for (int k = 0; k < size<1>(A); ++k) printf("+-----");
printf("+ ");
for (int n = 0; n < size<1>(C); ++n) printf("+-----");
printf("+\n");
// Values
for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA)));
printf("| ");
for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC)));
printf("|\n");
}
// Footer
for (int k = 0; k < size<1>(A); ++k) printf("+-----");
printf("+ ");
for (int n = 0; n < size<1>(C); ++n) printf("+-----");
printf("+\n");
}
// MNK MMA Layout to SVG -- 8-value color coded by thread
template <class LayoutC, class ThrIDC,
class LayoutA, class ThrIDA,
class LayoutB, class ThrIDB>
CUTE_HOST_DEVICE
void
print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx
{
char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175",
"255,175,175", "210,210,255", "210,255,210",
"255,255,210", "255,210,210"};
const int cell_width = 20;
const int cell_height = 20;
const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width;
const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height;
// header
printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
"preserveAspectRatio=\"xMidYMid meet\" "
"xmlns=\"http://www.w3.org/2000/svg\">\n",
page_width, page_height);
// C
int c_base_x = (size<1>(A) + 2) * cell_width;
int c_base_y = (size<1>(B) + 2) * cell_height;
for (int m = 0; m < cute::size<0>(C); ++m) {
for (int n = 0; n < cute::size<1>(C); ++n) {
int thrid = C(m, n) % size(TC);
int val_idx = C(m, n) / size(TC);
int thr_idx = TC(thrid);
int x = n * cell_width + c_base_x;
int y = m * cell_height + c_base_y;
int thr_x = x + cell_width / 2;
int thr_y = y + cell_height / 4;
int val_x = x + cell_width / 2;
int val_y = y + cell_height * 3 / 4;
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\"/>\n",
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
thr_x, thr_y, thr_idx);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
val_x, val_y, val_idx);
}
}
// A
int a_base_x = cell_width;
int a_base_y = (size<1>(B) + 2) * cell_height;
for (int m = 0; m < size<0>(A); ++m) {
for (int k = 0; k < size<1>(A); ++k) {
int thrid = A(m, k) % size(TA);
int val_idx = A(m, k) / size(TA);
int thr_idx = TA(thrid);
int x = k * cell_width + a_base_x;
int y = m * cell_height + a_base_y;
int thr_x = x + cell_width / 2;
int thr_y = y + cell_height / 4;
int val_x = x + cell_width / 2;
int val_y = y + cell_height * 3 / 4;
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
thr_x, thr_y, thr_idx);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
val_x, val_y, val_idx);
}
}
// B
int b_base_x = (size<1>(A) + 2) * cell_width;
int b_base_y = cell_height;
for (int n = 0; n < size<0>(B); ++n) {
for (int k = 0; k < size<1>(B); ++k) {
int thrid = B(n, k) % size(TB);
int val_idx = B(n, k) / size(TB);
int thr_idx = TB(thrid);
int x = n * cell_width + b_base_x;
int y = k * cell_height + b_base_y;
int thr_x = x + cell_width / 2;
int thr_y = y + cell_height / 4;
int val_x = x + cell_width / 2;
int val_y = y + cell_height * 3 / 4;
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
thr_x, thr_y, thr_idx);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
val_x, val_y, val_idx);
}
}
// A labels
for (int m = 0; m < size<0>(A); ++m) {
int x = cell_width / 2;
int y = m * cell_height + cell_height / 2 + a_base_y;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, m);
}
for (int k = 0; k < size<1>(A); ++k) {
int x = cell_width + k * cell_width + cell_width / 2;
int y = -cell_height / 2 + a_base_y;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, k);
}
// B labels
for (int n = 0; n < size<0>(B); ++n) {
int x = b_base_x + cell_width * n + cell_width / 2;
int y = cell_height / 2;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, n);
}
for (int k = 0; k < size<1>(B); ++k) {
int x = b_base_x - cell_width / 2;
int y = cell_height * (k + 1) + cell_height / 2;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, k);
}
// footer
printf("</svg>\n");
}
template <class... Args>
CUTE_HOST_DEVICE
void
print_svg(MMA_Atom<Args...> const &mma_atom) {
print_svg(make_tiled_mma(mma_atom));
}
template <class... Args>
CUTE_HOST_DEVICE
void
print_svg(TiledMMA<Args...> const &mma) {
auto layout_and_thrid_C = mma.get_layoutC_MN();
auto layoutC_MN = get<0>(layout_and_thrid_C);
auto thrID_C = get<1>(layout_and_thrid_C);
auto layout_and_thrid_A = mma.get_layoutA_MK();
auto layoutA_MK = get<0>(layout_and_thrid_A);
auto thrID_A = get<1>(layout_and_thrid_A);
auto layout_and_thrid_B = mma.get_layoutB_NK();
auto layoutB_NK = get<0>(layout_and_thrid_B);
auto thrID_B = get<1>(layout_and_thrid_B);
print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B);
}
} // namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -1114,7 +698,7 @@ print_svg(TiledMMA<Args...> const &mma) {
#include <cute/atom/mma_traits_sm89.hpp>
#include <cute/atom/mma_traits_sm90.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>
#include <cute/atom/mma_traits_sm100.hpp>
#include <cute/atom/mma_traits_sm100.hpp>
#include <cute/atom/mma_traits_sm120.hpp>
#include <cute/atom/mma_traits_sm120_sparse.hpp>

View File

@ -3844,4 +3844,39 @@ struct MMA_Traits<SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE<a_type, b_type, c_type, sf_
}
};
/**
* Specialization for a vectorized FMA per thread.
*/
template <>
struct MMA_Traits<SM100_2x1x1_F32F32F32F32>
{
using ValTypeD = float;
using ValTypeA = float;
using ValTypeB = float;
using ValTypeC = float;
using Shape_MNK = Shape<_2,_1,_1>;
using ThrID = Layout<_1>;
using ALayout = Layout<Shape<_1,_2>>;
using BLayout = Layout<Shape<_1,_1>>;
using CLayout = Layout<Shape<_1,_2>>;
};
template <>
struct MMA_Traits<SM100_1x2x1_F32F32F32F32>
{
using ValTypeD = float;
using ValTypeA = float;
using ValTypeB = float;
using ValTypeC = float;
using Shape_MNK = Shape<_1,_2,_1>;
using ThrID = Layout<_1>;
using ALayout = Layout<Shape<_1,_1>>;
using BLayout = Layout<Shape<_1,_2>>;
using CLayout = Layout<Shape<_1,_2>>;
};
} // end namespace cute

View File

@ -834,7 +834,7 @@ coalesce_x(Layout<Shape,Stride> const& layout)
} else {
return detail::bw_coalesce<R-2>(flat_shape, flat_stride, get<R-1>(flat_shape), get<R-1>(flat_stride));
}
CUTE_GCC_UNREACHABLE;
}
@ -1944,185 +1944,4 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout<Shape,Stride> const&
}
#endif
// Generic 2D Layout to console table
template <class Layout>
CUTE_HOST_DEVICE
void
print_layout(Layout const& layout) // (m,n) -> idx
{
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
int idx_width = num_digits(cosize(layout)) + 2;
const char* delim = "+-----------------------";
print(layout); print("\n");
// Column indices
print(" ");
for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); }
printf("\n");
// Print out A m-by-n
for (int m = 0; m < size<0>(layout); ++m) {
// Header
print(" ");
for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
printf("+\n");
// Values
printf("%2d ", m); // Row indices
for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); }
printf("|\n");
}
// Footer
print(" ");
for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
printf("+\n");
}
// Generic ThrVal 2D Layout to console table
template <class Layout, class ThrID>
CUTE_HOST_DEVICE
void
print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx
{
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
print(layout); print("\n");
print(thrid); print("\n");
// Print out m-by-n
for (int m = 0; m < size<0>(layout); ++m) {
// Header
for (int n = 0; n < size<1>(layout); ++n) printf("+------");
printf("+\n");
// Values
for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid)));
printf("|\n");
}
// Footer
for (int n = 0; n < size<1>(layout); ++n) printf("+------");
printf("+\n");
}
struct TikzColor_White {
CUTE_HOST_DEVICE char const*
operator()(int idx) const {
return "white";
}
};
struct TikzColor_BWx8 {
CUTE_HOST_DEVICE char const*
operator()(int idx) const {
static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60",
"black!10", "black!50", "black!30", "black!70"};
return color_map[idx % 8];
}
};
struct TikzColor_TV {
CUTE_HOST_DEVICE char const*
operator()(int tid, int vid) const {
static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
"{rgb,255:red,175;green,255;blue,175}",
"{rgb,255:red,255;green,255;blue,175}",
"{rgb,255:red,255;green,175;blue,175}",
"{rgb,255:red,210;green,210;blue,255}",
"{rgb,255:red,210;green,255;blue,210}",
"{rgb,255:red,255;green,255;blue,210}",
"{rgb,255:red,255;green,210;blue,210}"};
return color_map[tid % 8];
}
};
// Generic 2D Layout to LaTeX printer
template <class LayoutA, class TikzColorFn = TikzColor_BWx8>
CUTE_HOST_DEVICE
void
print_latex(LayoutA const& layout_a, // (m,n) -> idx
TikzColorFn color = {}) // lambda(idx) -> tikz color string
{
CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{});
auto layout = append<2>(layout_a, Layout<_1,_0>{});
// Commented print(layout)
printf("%% Layout: "); print(layout); printf("\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
// Layout
for (int i = 0; i < size<0>(layout); ++i) {
for (int j = 0; j < size<1>(layout); ++j) {
int idx = layout(i,j);
printf("\\node[fill=%s] at (%d,%d) {%d};\n",
color(idx), i, j, idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
int(size<0>(layout)), int(size<1>(layout)));
// Labels
for (int i = 0, j = -1; i < size<0>(layout); ++i) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
}
for (int i = -1, j = 0; j < size<1>(layout); ++j) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
// Generic ThrVal 2D Layout to LaTeX TikZ
template <class Layout, class ThrID, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex(Layout const& layout, // (m,n) -> (tid,vid)
ThrID const& thr, // tid -> thr_idx
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
// Commented prints
printf("%% Layout: "); print(layout); printf("\n");
printf("%% ThrID : "); print(thr); printf("\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
// Layout
for (int i = 0; i < size<0>(layout); ++i) {
for (int j = 0; j < size<1>(layout); ++j) {
int thrid = layout(i,j) % size(thr);
int val_idx = layout(i,j) / size(thr);
int thr_idx = thr(thrid);
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(thr_idx, val_idx),
i, j,
thr_idx, val_idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
int(size<0>(layout)), int(size<1>(layout)));
// Labels
for (int i = 0, j = -1; i < size<0>(layout); ++i) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
}
for (int j = 0, i = -1; j < size<1>(layout); ++j) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
} // end namespace cute

View File

@ -42,22 +42,15 @@ namespace cute
{
template <class... T>
struct ArithmeticTuple : tuple<T...>
{
template <class... U>
struct ArithmeticTuple : public tuple<T...> {
CUTE_HOST_DEVICE constexpr
ArithmeticTuple(ArithmeticTuple<U...> const& u)
: tuple<T...>(static_cast<tuple<U...> const&>(u)) {}
ArithmeticTuple() : tuple<T...>() {}
template <class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple(tuple<U...> const& u)
: tuple<T...>(u) {}
ArithmeticTuple(tuple<T...> const& t) : tuple<T...>(t) {}
template <class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple(U const&... u)
: tuple<T...>(u...) {}
ArithmeticTuple(T const&... t) : tuple<T...>(t...) {}
};
template <class... T>
@ -147,12 +140,12 @@ operator-(ArithmeticTuple<T...> const& t) {
}
//
// Special cases
// Special cases for C<0>
//
template <auto t, class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<U...> const&
ArithmeticTuple<U...>
operator+(C<t>, ArithmeticTuple<U...> const& u) {
static_assert(t == 0, "Arithmetic tuple op+ error!");
return u;
@ -160,7 +153,7 @@ operator+(C<t>, ArithmeticTuple<U...> const& u) {
template <class... T, auto u>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<T...> const&
ArithmeticTuple<T...>
operator+(ArithmeticTuple<T...> const& t, C<u>) {
static_assert(u == 0, "Arithmetic tuple op+ error!");
return t;
@ -168,7 +161,7 @@ operator+(ArithmeticTuple<T...> const& t, C<u>) {
template <auto t, class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<U...> const&
ArithmeticTuple<U...>
operator-(C<t>, ArithmeticTuple<U...> const& u) {
static_assert(t == 0, "Arithmetic tuple op- error!");
return -u;
@ -176,7 +169,7 @@ operator-(C<t>, ArithmeticTuple<U...> const& u) {
template <class... T, auto u>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<T...> const&
ArithmeticTuple<T...>
operator-(ArithmeticTuple<T...> const& t, C<u>) {
static_assert(u == 0, "Arithmetic tuple op- error!");
return t;
@ -212,27 +205,20 @@ struct ArithmeticTupleIterator
}
};
template <class Tuple>
template <class... Ts>
CUTE_HOST_DEVICE constexpr
auto
make_inttuple_iter(Tuple const& t) {
return ArithmeticTupleIterator(as_arithmetic_tuple(t));
}
template <class T0, class T1, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) {
return make_inttuple_iter(cute::make_tuple(t0, t1, ts...));
make_inttuple_iter(Ts const&... ts) {
return ArithmeticTupleIterator(as_arithmetic_tuple(ts...));
}
//
// ArithmeticTuple "basis" elements
// A ScaledBasis<T,N> is a (at least) rank-N+1 ArithmeticTuple:
// A ScaledBasis<T,Ns...> is a (at least) rank-N+1 ArithmeticTuple:
// (_0,_0,...,T,_0,...)
// with value T in the Nth mode
template <class T, int N>
template <class T, int... Ns>
struct ScaledBasis : private tuple<T>
{
CUTE_HOST_DEVICE constexpr
@ -243,40 +229,61 @@ struct ScaledBasis : private tuple<T>
CUTE_HOST_DEVICE constexpr
decltype(auto) value() const { return get<0>(static_cast<tuple<T> const&>(*this)); }
// Deprecated: Get the first hierarchical mode in this basis.
CUTE_HOST_DEVICE static constexpr
auto mode() { return Int<N>{}; }
auto mode() { return get<0>(int_sequence<Ns...>{}); }
};
// Ensure flat representation
template <class T, int... Ms, int... Ns>
struct ScaledBasis<ScaledBasis<T, Ms...>, Ns...> : ScaledBasis<T, Ns..., Ms...> {};
template <class T>
struct is_scaled_basis : false_type {};
template <class T, int N>
struct is_scaled_basis<ScaledBasis<T,N>> : true_type {};
template <class T, int... Ns>
struct is_scaled_basis<ScaledBasis<T,Ns...>> : true_type {};
template <class T, int N>
struct is_integral<ScaledBasis<T,N>> : true_type {};
template <class T, int... Ns>
struct is_integral<ScaledBasis<T,Ns...>> : true_type {};
// Get the scalar T out of a ScaledBasis
template <class SB>
CUTE_HOST_DEVICE constexpr auto
basis_value(SB const& e)
// Shortcuts
// E<> := _1
// E<0> := (_1,_0,_0,...)
// E<1> := (_0,_1,_0,...)
// E<0,0> := ((_1,_0,_0,...),_0,_0,...)
// E<0,1> := ((_0,_1,_0,...),_0,_0,...)
// E<1,0> := (_0,(_1,_0,_0,...),_0,...)
// E<1,1> := (_0,(_0,_1,_0,...),_0,...)
template <int... Ns>
using E = ScaledBasis<Int<1>,Ns...>;
// Apply the Ns... pack to another Tuple
template <class T, class Tuple>
CUTE_HOST_DEVICE decltype(auto)
basis_get(T const&, Tuple&& t)
{
if constexpr (is_scaled_basis<SB>::value) {
return basis_value(e.value());
return static_cast<Tuple&&>(t);
}
template <class T, int... Ns, class Tuple>
CUTE_HOST_DEVICE decltype(auto)
basis_get(ScaledBasis<T,Ns...> const&, Tuple&& t)
{
if constexpr (sizeof...(Ns) == 0) {
return static_cast<Tuple&&>(t);
} else {
return e;
return get<Ns...>(static_cast<Tuple&&>(t));
}
CUTE_GCC_UNREACHABLE;
}
// Apply the N... pack to another Tuple
template <class SB, class Tuple>
template <class T>
CUTE_HOST_DEVICE decltype(auto)
basis_get(SB const& e, Tuple&& t)
{
if constexpr (is_scaled_basis<SB>::value) {
return basis_get(e.value(), get<SB::mode()>(static_cast<Tuple&&>(t)));
basis_value(T const& e) {
if constexpr (is_scaled_basis<T>::value) {
return e.value();
} else {
return static_cast<Tuple&&>(t);
return e;
}
CUTE_GCC_UNREACHABLE;
}
@ -294,65 +301,34 @@ to_atuple_i(T const& t, seq<I...>) {
// Turn a ScaledBases<T,N> into a rank-N+1 ArithmeticTuple
// with N prefix 0s: (_0,_0,...N...,_0,T)
template <class T, int N>
template <class T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq<N>{});
as_arithmetic_tuple(ScaledBasis<T> const& t) {
return t.value();
}
namespace detail {
template <class T, int N, int... Ns>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N,Ns...> const& t) {
return detail::to_atuple_i(as_arithmetic_tuple(ScaledBasis<T,Ns...>{t.value()}), make_seq<N>{});
}
template <int... Ns>
struct Basis;
template <>
struct Basis<> {
using type = Int<1>;
};
template <int N, int... Ns>
struct Basis<N,Ns...> {
using type = ScaledBasis<typename Basis<Ns...>::type, N>;
};
} // end namespace detail
// Shortcut for writing ScaledBasis<ScaledBasis<ScaledBasis<Int<1>, N0>, N1>, ...>
// E<> := _1
// E<0> := (_1,_0,_0,...)
// E<1> := (_0,_1,_0,...)
// E<0,0> := ((_1,_0,_0,...),_0,_0,...)
// E<0,1> := ((_0,_1,_0,...),_0,_0,...)
// E<1,0> := (_0,(_1,_0,_0,...),_0,...)
// E<1,1> := (_0,(_0,_1,_0,...),_0,...)
template <int... N>
using E = typename detail::Basis<N...>::type;
template <class Shape>
template <int... Ns, class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_basis_like(Shape const& shape)
{
if constexpr (is_integral<Shape>::value) {
return Int<1>{};
} else {
// Generate bases for each rank of shape
if constexpr (is_tuple<Shape>::value) {
// Generate bases for each mode of shape
return transform(tuple_seq<Shape>{}, shape, [](auto I, auto si) {
// Generate bases for each rank of si and add an i on front
using I_type = decltype(I);
return transform_leaf(make_basis_like(si), [](auto e) {
// MSVC has trouble capturing variables as constexpr,
// so that they can be used as template arguments.
// This is exactly what the code needs to do with i, unfortunately.
// The work-around is to define i inside the inner lambda,
// by using just the type from the enclosing scope.
constexpr int i = I_type::value;
return ScaledBasis<decltype(e), i>{};
});
// Generate bases for each si and add an i on end
return make_basis_like<Ns...,decltype(I)::value>(si);
});
} else {
return E<Ns...>{};
}
CUTE_GCC_UNREACHABLE;
}
@ -360,109 +336,124 @@ make_basis_like(Shape const& shape)
// Arithmetic
//
template <class T, int M, class U>
template <class T, int... Ns, class U>
CUTE_HOST_DEVICE constexpr
auto
safe_div(ScaledBasis<T,M> const& b, U const& u)
safe_div(ScaledBasis<T,Ns...> const& b, U const& u)
{
auto t = safe_div(b.value(), u);
return ScaledBasis<decltype(t),M>{t};
return ScaledBasis<decltype(t),Ns...>{t};
}
template <class T, int M, class U>
template <class T, int... Ns, class U>
CUTE_HOST_DEVICE constexpr
auto
ceil_div(ScaledBasis<T,M> const& b, U const& u)
ceil_div(ScaledBasis<T,Ns...> const& b, U const& u)
{
auto t = ceil_div(b.value(), u);
return ScaledBasis<decltype(t),M>{t};
return ScaledBasis<decltype(t),Ns...>{t};
}
template <class T, int N>
template <class T, int... Ns>
CUTE_HOST_DEVICE constexpr
auto
abs(ScaledBasis<T,N> const& e)
abs(ScaledBasis<T,Ns...> const& e)
{
auto t = abs(e.value());
return ScaledBasis<decltype(t),N>{t};
return ScaledBasis<decltype(t),Ns...>{t};
}
// Equality
template <class T, int N, class U, int M>
template <class T, int... Ns, class U, int... Ms>
CUTE_HOST_DEVICE constexpr
auto
operator==(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
return bool_constant<M == N>{} && t.value() == u.value();
operator==(ScaledBasis<T,Ns...> const& t, ScaledBasis<U,Ms...> const& u) {
if constexpr (sizeof...(Ns) == sizeof...(Ms)) {
return bool_constant<((Ns == Ms) && ...)>{} && t.value() == u.value();
} else {
return false_type{};
}
CUTE_GCC_UNREACHABLE;
}
// Not equal to anything else
template <class T, int N, class U>
template <class T, int... Ns, class U>
CUTE_HOST_DEVICE constexpr
false_type
operator==(ScaledBasis<T,N> const&, U const&) {
operator==(ScaledBasis<T,Ns...> const&, U const&) {
return {};
}
template <class T, class U, int M>
template <class T, class U, int... Ms>
CUTE_HOST_DEVICE constexpr
false_type
operator==(T const&, ScaledBasis<U,M> const&) {
operator==(T const&, ScaledBasis<U,Ms...> const&) {
return {};
}
// Multiplication
template <class A, class T, int N>
template <class A, class T, int... Ns>
CUTE_HOST_DEVICE constexpr
auto
operator*(A const& a, ScaledBasis<T,N> const& e) {
operator*(A const& a, ScaledBasis<T,Ns...> const& e) {
auto r = a * e.value();
return ScaledBasis<decltype(r),N>{r};
return ScaledBasis<decltype(r),Ns...>{r};
}
template <class T, int N, class B>
template <class T, int... Ns, class B>
CUTE_HOST_DEVICE constexpr
auto
operator*(ScaledBasis<T,N> const& e, B const& b) {
operator*(ScaledBasis<T,Ns...> const& e, B const& b) {
auto r = e.value() * b;
return ScaledBasis<decltype(r),N>{r};
return ScaledBasis<decltype(r),Ns...>{r};
}
// Addition
template <class T, int N, class U, int M>
template <class T, int... Ns, class U, int... Ms>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
operator+(ScaledBasis<T,Ns...> const& t, ScaledBasis<U,Ms...> const& u) {
return as_arithmetic_tuple(t) + as_arithmetic_tuple(u);
}
template <class T, int N, class... U>
template <class T, int... Ns, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
operator+(ScaledBasis<T,Ns...> const& t, ArithmeticTuple<U...> const& u) {
return as_arithmetic_tuple(t) + u;
}
template <class... T, class U, int M>
template <class... T, class U, int... Ms>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,Ms...> const& u) {
return t + as_arithmetic_tuple(u);
}
template <auto t, class U, int M>
template <auto t, class U, int... Ms>
CUTE_HOST_DEVICE constexpr
auto
operator+(C<t>, ScaledBasis<U,M> const& u) {
static_assert(t == 0, "ScaledBasis op+ error!");
return u;
operator+(C<t>, ScaledBasis<U,Ms...> const& u) {
if constexpr (sizeof...(Ms) == 0) {
return C<t>{} + u.value();
} else {
static_assert(t == 0, "ScaledBasis op+ error!");
return u;
}
CUTE_GCC_UNREACHABLE;
}
template <class T, int N, auto u>
template <class T, int... Ns, auto u>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, C<u>) {
static_assert(u == 0, "ScaledBasis op+ error!");
return t;
operator+(ScaledBasis<T,Ns...> const& t, C<u>) {
if constexpr (sizeof...(Ns) == 0) {
return t.value() + C<u>{};
} else {
static_assert(u == 0, "ScaledBasis op+ error!");
return t;
}
CUTE_GCC_UNREACHABLE;
}
//
@ -475,10 +466,10 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter)
printf("ArithTuple"); print(iter.coord_);
}
template <class T, int N>
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e)
template <class T, int... Ns>
CUTE_HOST_DEVICE void print(ScaledBasis<T,Ns...> const& e)
{
print(e.value()); printf("@%d", N);
print(e.value()); (void(printf("@%d", Ns)), ...);
}
#if !defined(__CUDACC_RTC__)
@ -488,10 +479,11 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator<Ari
return os << "ArithTuple" << iter.coord_;
}
template <class T, int N>
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e)
template <class T, int... Ns>
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,Ns...> const& e)
{
return os << e.value() << "@" << N;
os << e.value(); (void(os << "@" << Ns), ...);
return os;
}
#endif

View File

@ -47,8 +47,9 @@ namespace cute
// Signed integers
//
using int2_t = cutlass::int2b_t;
using int4_t = cutlass::int4b_t;
using int2_t = cutlass::int2b_t;
using int4_t = cutlass::int4b_t;
using int6_t = cutlass::int6b_t;
using CUTE_STL_NAMESPACE::int8_t;
using CUTE_STL_NAMESPACE::int16_t;
using CUTE_STL_NAMESPACE::int32_t;
@ -75,10 +76,10 @@ using int_byte_t = typename int_byte<N>::type;
// Unsigned integers
//
using uint1_t = cutlass::uint1b_t;
using uint2_t = cutlass::uint2b_t;
using uint4_t = cutlass::uint4b_t;
using uint6_t = cutlass::uint6b_t;
using uint1_t = cutlass::uint1b_t;
using uint2_t = cutlass::uint2b_t;
using uint4_t = cutlass::uint4b_t;
using uint6_t = cutlass::uint6b_t;
using CUTE_STL_NAMESPACE::uint8_t;
using CUTE_STL_NAMESPACE::uint16_t;
using CUTE_STL_NAMESPACE::uint32_t;
@ -88,7 +89,7 @@ template <int N> struct uint_bit;
template <> struct uint_bit< 1> { using type = uint1_t; };
template <> struct uint_bit< 2> { using type = uint2_t; };
template <> struct uint_bit< 4> { using type = uint4_t; };
template <> struct uint_bit< 6> { using type = uint6_t; };
template <> struct uint_bit< 6> { using type = uint6_t; };
template <> struct uint_bit< 8> { using type = uint8_t; };
template <> struct uint_bit< 16> { using type = uint16_t; };
template <> struct uint_bit< 32> { using type = uint32_t; };

View File

@ -38,10 +38,19 @@
namespace cute {
template <typename T>
struct sizeof_bits : public cutlass::sizeof_bits<T> {};
template <class T>
struct sizeof_bits : cutlass::sizeof_bits<T> {};
// DO NOT change auto to int, sizeof_bits<sparse_elem> use integral_ratio instead of int
template <class T>
struct sizeof_bits<T const> : sizeof_bits<T> {};
template <class T>
struct sizeof_bits<T volatile> : sizeof_bits<T> {};
template <class T>
struct sizeof_bits<T const volatile> : sizeof_bits<T> {};
// DO NOT change auto to int, sizeof_bits<sparse_elem> use integral_ratio instead of int
template <class T>
static constexpr auto sizeof_bits_v = sizeof_bits<T>::value;
@ -53,6 +62,23 @@ using cutlass::is_subbyte;
template <class T>
static constexpr auto is_subbyte_v = is_subbyte<T>::value;
//
// Integral
//
using cutlass::bin1_t;
using cutlass::uint1b_t;
using cutlass::int2b_t;
using cutlass::uint2b_t;
using cutlass::int4b_t;
using cutlass::uint4b_t;
using cutlass::int6b_t;
using cutlass::uint6b_t;
//
// Floating Point
//
using cutlass::half_t;
using cutlass::bfloat16_t;
@ -65,18 +91,12 @@ using cutlass::type_erased_dynamic_float8_t;
using cutlass::float_e4m3_t;
using cutlass::float_e5m2_t;
using cutlass::uint1b_t;
using cutlass::int2b_t;
using cutlass::uint2b_t;
using cutlass::int4b_t;
using cutlass::uint4b_t;
using cutlass::bin1_t;
using cutlass::float_ue4m3_t;
using cutlass::float_ue8m0_t;
using cutlass::uint6b_t;
using cutlass::float_e2m1_t;
using cutlass::float_e2m3_t;
using cutlass::float_e3m2_t;
@ -94,8 +114,6 @@ using cutlass::detail::type_erased_dynamic_float4_unpacksmem_t;
using cutlass::detail::type_erased_dynamic_float6_unpacksmem_t;
};
//
// Print utility
//
@ -112,7 +130,6 @@ print(bfloat16_t a) {
printf("%f", static_cast<float>(a));
}
CUTE_HOST_DEVICE
void
print(tfloat32_t a) {
@ -131,6 +148,15 @@ print(float_e5m2_t a) {
printf("%f", static_cast<float>(a));
}
template <cutlass::detail::FpEncoding Encoding, class Derived>
CUTE_HOST_DEVICE
void
print(cutlass::float_exmy_base<Encoding, Derived> a) {
printf("%f", static_cast<float>(a));
}
// Pretty Print utility
CUTE_HOST_DEVICE void
pretty_print(bfloat16_t v) {
printf("%*.2f", 8, float(v));
@ -156,26 +182,11 @@ pretty_print(float_e5m2_t t) {
printf("%*.2f", 8, static_cast<float>(t));
}
template <
cutlass::detail::FpEncoding Encoding,
class Derived
>
CUTE_HOST_DEVICE
void
print(cutlass::float_exmy_base<Encoding, Derived> a) {
printf("%f", static_cast<float>(a));
}
template <
cutlass::detail::FpEncoding Encoding,
class Derived
>
template <cutlass::detail::FpEncoding Encoding, class Derived>
CUTE_HOST_DEVICE
void
pretty_print_float_exmy_base(cutlass::float_exmy_base<Encoding, Derived> t) {
printf("%*.2f", 8, static_cast<float>(t));
}
} // namespace cute

View File

@ -33,9 +33,9 @@
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/pointer_base.hpp> // cute::iter_adaptor
#include <cute/pointer_sparse.hpp>
#include <cute/container/array_subbyte.hpp> // cute::subbyte_iterator
#include <cute/numeric/integral_constant.hpp> // cute::true_type, cute::false_type
#include <cute/numeric/numeric_types.hpp> // sizeof_bits
#include <cute/container/array_subbyte.hpp> // cute::subbyte_iterator
namespace cute
{
@ -51,11 +51,13 @@ namespace cute
// Requires construction of a sparse_ptr that emulates access to the S logical elements.
//
template <class NewT>
template <class NewT_, class T>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void* ptr)
recast_ptr(T* ptr)
{
using NewT = copy_cv_t<T, NewT_>;
if constexpr (is_sparse<NewT>::value) {
constexpr int sparsity = NewT::sparsity;
NewT* p = reinterpret_cast<NewT*>(ptr);
@ -69,24 +71,6 @@ recast_ptr(void* ptr)
CUTE_GCC_UNREACHABLE;
}
template <class NewT>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void const* ptr)
{
if constexpr (is_sparse<NewT>::value) {
constexpr int sparsity = NewT::sparsity;
NewT const* p = reinterpret_cast<NewT const*>(ptr);
return make_sparse_ptr<sparsity>(p);
} else
if constexpr (cute::is_subbyte_v<NewT>) {
return subbyte_iterator<NewT const>(ptr);
} else {
return reinterpret_cast<NewT const*>(ptr);
}
CUTE_GCC_UNREACHABLE;
}
// Disambiguate nullptr
template <class NewT>
CUTE_HOST_DEVICE constexpr

View File

@ -167,23 +167,6 @@ downcast(ComposedLayout<SwizzleFn,smem_sparse_ptr_flag_bits<S,B>,Layout> const&
// Display utilities
//
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts
template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
void
print_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
print_layout(as_position_independent_swizzle_layout(layout));
}
template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
void
print_latex(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
print_latex(as_position_independent_swizzle_layout(layout));
}
template <int B>
CUTE_HOST_DEVICE void print(smem_ptr_flag_bits<B> ptr)
{

View File

@ -56,3 +56,9 @@
#include <cute/algorithm/cooperative_copy.hpp>
#include <cute/algorithm/cooperative_gemm.hpp>
//
// Utilities
//
#include <cute/util/print_tensor.hpp>
#include <cute/util/print_latex.hpp>

View File

@ -753,24 +753,30 @@ domain_offset(Coord const& coord, Tensor&& tensor)
// -- doesn't check dynamic integer divisibility
// -- doesn't check alignment
template <class NewType, class Tensor>
template <class NewType_, class Tensor>
CUTE_HOST_DEVICE constexpr
auto
recast(Tensor&& tensor)
{
using OldType = typename remove_cvref_t<Tensor>::value_type;
using OldType = typename remove_cvref_t<Tensor>::element_type;
using NewType = copy_cv_t<OldType, NewType_>;
auto old_layout = tensor.layout();
auto new_layout = recast_layout<OldType,NewType>(old_layout);
// If this is an upcast of a normal Layout with static negative strides, then offset as well
if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() + offset), new_layout);
if constexpr (is_same<NewType, OldType>::value) {
return tensor;
} else {
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() ), new_layout);
// If this is an upcast of a normal Layout with static negative strides, then offset as well
if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() + offset), new_layout);
} else {
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() ), new_layout);
}
}
CUTE_GCC_UNREACHABLE;
@ -1114,95 +1120,5 @@ CUTE_HOST_DEVICE void print(Tensor<Engine,Layout> const& tensor)
print(tensor.data()); print(" o "); print(tensor.layout());
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor, bool print_type = true)
{
if (print_type) {
print(tensor); print(":\n");
}
if constexpr (Layout::rank == 1)
{
for (int m = 0; m < size(tensor); ++m) {
pretty_print(tensor(m));
printf("\n");
}
} else
if constexpr (Layout::rank == 2)
{
for (int m = 0; m < size<0>(tensor); ++m) {
for (int n = 0; n < size<1>(tensor); ++n) {
pretty_print(tensor(m,n));
}
printf("\n");
}
} else
if constexpr (Layout::rank == 3)
{
print_tensor(tensor(_,_,0), false);
for (int k = 1; k < size<2>(tensor); ++k) {
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n");
print_tensor(tensor(_,_,k), false);
}
} else
if constexpr (Layout::rank == 4)
{
print_tensor(tensor(_,_,_,0), false);
for (int p = 1; p < size<3>(tensor); ++p) {
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n");
print_tensor(tensor(_,_,_,p), false);
}
}
}
#if !defined(__CUDACC_RTC__)
template <class Engine, class Layout>
CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
int digits = 9;
if constexpr (Layout::rank == 1)
{
for (int m = 0; m < size(tensor); ++m) {
os << std::setw(digits) << tensor(m) << std::endl;
}
} else
if constexpr (Layout::rank == 2)
{
for (int m = 0; m < size<0>(tensor); ++m) {
for (int n = 0; n < size<1>(tensor); ++n) {
os << std::setw(digits) << tensor(m,n);
}
os << std::endl;
}
} else
if constexpr (Layout::rank == 3)
{
print_tensor_os(os, tensor(_,_,0));
for (int k = 1; k < size<2>(tensor); ++k) {
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl;
print_tensor_os(os, tensor(_,_,k));
}
} else
if constexpr (Layout::rank == 4)
{
print_tensor_os(os, tensor(_,_,_,0));
for (int p = 1; p < size<3>(tensor); ++p) {
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl;
print_tensor_os(os, tensor(_,_,_,p));
}
}
return os;
}
template <class Engine, class Layout>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
os << tensor.layout() << std::endl;
return print_tensor_os(os, tensor);
}
#endif // !defined(__CUDACC_RTC__)
} // end namespace cute

View File

@ -0,0 +1,438 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/layout.hpp>
#include <cute/tensor_impl.hpp>
namespace cute
{
///////////////////////////////////////
// Common LaTeX TikZ Color utilities //
///////////////////////////////////////
struct TikzColor_White {
CUTE_HOST_DEVICE char const*
operator()(int idx) const {
return "white";
}
};
struct TikzColor_BWx8 {
CUTE_HOST_DEVICE char const*
operator()(int idx) const {
static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60",
"black!10", "black!50", "black!30", "black!70"};
return color_map[idx % 8];
}
};
struct TikzColor_TV {
CUTE_HOST_DEVICE char const*
operator()(int tid, int vid) const {
static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
"{rgb,255:red,175;green,255;blue,175}",
"{rgb,255:red,255;green,255;blue,175}",
"{rgb,255:red,255;green,175;blue,175}",
"{rgb,255:red,210;green,210;blue,255}",
"{rgb,255:red,210;green,255;blue,210}",
"{rgb,255:red,255;green,255;blue,210}",
"{rgb,255:red,255;green,210;blue,210}"};
return color_map[tid % 8];
}
};
/////////////////////////////
// Layout 2D to LaTeX TikZ //
/////////////////////////////
template <class LayoutA, class TikzColorFn = TikzColor_BWx8>
CUTE_HOST_DEVICE
void
print_latex(LayoutA const& layout_a, // (m,n) -> idx
TikzColorFn color = {}) // lambda(idx) -> tikz color string
{
CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{});
auto layout = append<2>(layout_a, Layout<_1,_0>{});
// Commented print(layout)
printf("%% Layout: "); print(layout); printf("\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
auto [M, N] = product_each(shape(layout));
// Layout
for (int m = 0; m < M; ++m) {
for (int n = 0; n < N; ++n) {
int idx = layout(m,n);
printf("\\node[fill=%s] at (%d,%d) {%d};\n",
color(idx), m, n, idx);
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
int(M), int(N));
// Labels
for (int m = 0, n = -1; m < M; ++m) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m);
}
for (int m = -1, n = 0; n < N; ++n) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
template <class SwizzleFn, int B, class Layout, class TikzColorFn = TikzColor_BWx8>
CUTE_HOST_DEVICE
void
print_latex(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout,
TikzColorFn color = {}) // lambda(idx) -> tikz color string)
{
print_latex(as_position_independent_swizzle_layout(layout), color);
}
///////////////////////////////
// LayoutTV 2D to LaTeX TikZ //
///////////////////////////////
template <class LayoutTV, class Tile_MN,
class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_tv(LayoutTV const& layout_tv, // (t,v) -> m,n coord
Tile_MN const& tile_mn, // (M,N)
TikzColorFn color = {}) // (t,v) -> color
{
CUTE_STATIC_ASSERT_V(rank(layout_tv) == Int<2>{});
// Commented prints
printf("%% Layout TV: "); print(layout_tv); printf("\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
auto [M, N] = product_each(shape(tile_mn));
Tensor filled = make_tensor<bool>(make_shape(M, N));
clear(filled);
// Layout
for (int tid = 0; tid < size<0>(layout_tv); ++tid) {
for (int vid = 0; vid < size<1>(layout_tv); ++vid) {
auto [m, n] = layout_tv(tid, vid);
if (not filled(m, n)) {
filled(m, n) = true;
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(tid, vid),
int(m), int(n),
tid, vid);
}
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", int(M), int(N));
// Labels
for (int m = 0, n = -1; m < M; ++m) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m);
}
for (int n = 0, m = -1; n < N; ++n) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
////////////////////////////
// MMA Atom to LaTeX TikZ //
////////////////////////////
namespace detail {
template <class LayoutC, class LayoutA, class LayoutB, class Tile_MNK,
class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_mma(LayoutC const& C, // (tid,vid) -> (m,n) coord
LayoutA const& A, // (tid,vid) -> (m,k) coord
LayoutB const& B, // (tid,vid) -> (n,k) coord
Tile_MNK const& tile_mnk, // (M,N,K)
TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string
{
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
// Commented prints
printf("%% LayoutC: "); print(C); printf("\n");
printf("%% LayoutA: "); print(A); printf("\n");
printf("%% LayoutB: "); print(B); printf("\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
auto [M, N, K] = product_each(shape(tile_mnk));
Tensor filled = make_tensor<bool>(make_shape(M, N, K));
clear(filled);
// C starting at 0,0
for (int tid = 0; tid < size<0>(C); ++tid) {
for (int vid = 0; vid < size<1>(C); ++vid) {
auto [m, n] = C(tid, vid);
if (not filled(m, n, 0)) {
filled(m, n, 0) = true;
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(tid, vid),
int(m), int(n),
tid, vid);
}
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, 0, int(M), int(N));
clear(filled);
// A starting at 0,-K-1
for (int tid = 0; tid < size<0>(A); ++tid) {
for (int vid = 0; vid < size<1>(A); ++vid) {
auto [m, k] = A(tid, vid);
if (not filled(m, 0, k)) {
filled(m, 0, k) = true;
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(tid, vid),
int(m), int(k-K-1),
tid, vid);
}
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, -int(K)-1, int(M), -1);
// A labels
for (int m = 0, k = -1; m < M; ++m) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), m);
}
for (int m = -1, k = 0; k < K; ++k) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), k);
}
clear(filled);
// B starting at -K-1,0
for (int tid = 0; tid < size<0>(B); ++tid) {
for (int vid = 0; vid < size<1>(B); ++vid) {
auto [n, k] = B(tid, vid);
if (not filled(0, n, k)) {
filled(0, n, k) = true;
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(tid, vid),
int(k)-int(K)-1, int(n),
tid, vid);
}
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
-int(K)-1, 0, -1, int(N));
// B labels
for (int n = 0, k = -1; n < N; ++n) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, n);
}
for (int n = -1, k = 0; k < K; ++k) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, k);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
} // end namespace detail
// MMA Atom to LaTeX TikZ
template <class... Args, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex(MMA_Atom<Args...> const& mma_atom,
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
print_latex(make_tiled_mma(mma_atom));
}
// TiledMMA to LaTeX TikZ
template <class... Args, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex(TiledMMA<Args...> const& mma,
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
{
auto tile_mnk = tile_shape(mma);
Tensor refC = make_identity_tensor(select<0,1>(tile_mnk));
Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV());
Tensor refA = make_identity_tensor(select<0,2>(tile_mnk));
Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV());
Tensor refB = make_identity_tensor(select<1,2>(tile_mnk));
Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV());
detail::print_latex_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color);
}
////////////////////////////
// CopyAtom to LaTeX TikZ //
////////////////////////////
namespace detail {
// Generic TV Layout to LaTeX TikZ
template <class LayoutS_TV, class LayoutD_TV, class Tile_MN,
class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_copy(LayoutS_TV const& S, // (t,v) -> m,n coord
LayoutD_TV const& D, // (t,v) -> m,n coord
Tile_MN const& tile_mn, // (M,N)
TikzColorFn color = {}) // (t,v) -> color
{
CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});
// Commented prints
printf("%% Layout S TV: "); print(S); printf("\n");
printf("%% Layout D TV: "); print(D); printf("\n");
// Header
printf("\\documentclass[convert]{standalone}\n"
"\\usepackage{tikz}\n\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
auto [M, N] = product_each(shape(tile_mn));
Tensor filled = make_tensor<bool>(make_shape(M, N));
clear(filled);
// S starting at 0,0
for (int tid = 0; tid < size<0>(S); ++tid) {
for (int vid = 0; vid < size<1>(S); ++vid) {
auto [m, n] = S(tid, vid);
if (not filled(m, n)) {
filled(m, n) = true;
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(tid, vid),
int(m), int(n),
tid, vid);
}
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, 0, int(M), int(N));
// S Labels
for (int m = 0, n = -1; m < M; ++m) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m);
}
for (int m = -1, n = 0; n < N; ++n) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n);
}
clear(filled);
// D starting at 0,N+3
for (int tid = 0; tid < size<0>(D); ++tid) {
for (int vid = 0; vid < size<1>(D); ++vid) {
auto [m, n] = D(tid, vid);
if (not filled(m, n)) {
filled(m, n) = true;
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color(tid, vid),
int(m), int(n) + int(N) + 3,
tid, vid);
}
}
}
// Grid
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
0, int(N) + 3, int(M), int(N) + int(N) + 3);
// D Labels
for (int m = 0, n = N; m < M; ++m) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), m);
}
for (int m = -1, n = 0; n < N; ++n) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), n);
}
// Footer
printf("\\end{tikzpicture}\n"
"\\end{document}\n");
}
} // end namespace detail
// TiledCopy to LaTeX TikZ
template <class... Args, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex(TiledCopy<Args...> const& copy,
TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string
{
auto tiler_mn = typename TiledCopy<Args...>::Tiler_MN{};
auto tile_mn = product_each(shape(logical_divide(make_layout(Shape<_1,_1>{}), tiler_mn))); // tile_shape
Tensor refS = make_identity_tensor(tile_mn);
Tensor layoutS_TV = copy.tidfrg_S(refS)(_,_,Int<0>{});
Tensor refD = make_identity_tensor(tile_mn);
Tensor layoutD_TV = copy.tidfrg_D(refD)(_,_,Int<0>{});
detail::print_latex_copy(layoutS_TV, layoutD_TV, tile_mn, color);
}
} // end namespace cute

View File

@ -0,0 +1,257 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/layout.hpp>
#include <cute/tensor_impl.hpp>
namespace cute
{
////////////////////////////////
// Common SVG Color utilities //
////////////////////////////////
struct TSVGColor_White {
CUTE_HOST_DEVICE char const*
operator()(int idx) const {
return "255,255,255";
}
};
struct TSVGColor_BWx8 {
CUTE_HOST_DEVICE char const*
operator()(int idx) const {
static char const* color_map[8] = {"255,255,255", "230,230,230", "205,205,205", "180,180,180",
"155,155,155", "130,130,130", "105,105,105", "080,080,080"};
return color_map[idx % 8];
}
};
struct SVGColor_TV {
CUTE_HOST_DEVICE char const*
operator()(int tid, int vid) const {
static char const* color_map[8] = {"175,175,255", "175,255,175", "255,255,175", "255,175,175",
"210,210,255", "210,255,210", "255,255,210", "255,210,210"};
return color_map[tid % 8];
}
};
/////////////////////
// MMA Atom to SVG //
/////////////////////
namespace detail {
template <class LayoutC, class LayoutA, class LayoutB, class Tile_MNK,
class SVGColorFn = SVGColor_TV>
CUTE_HOST_DEVICE
void
print_svg_mma(LayoutC const& C,
LayoutA const& A,
LayoutB const& B,
Tile_MNK const& tile_mnk,
SVGColorFn color = {}) // lambda(tid,vid) -> SVG color string
{
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
auto [M, N, K] = product_each(shape(tile_mnk));
int cell_size = 20;
int page_width = (K + N + 2) * cell_size;
int page_height = (K + M + 2) * cell_size;
// Commented print
printf("<!-- Tile: "); print(tile_mnk); printf(" -->\n");
printf("<!-- A: "); print(A); printf(" -->\n");
printf("<!-- B: "); print(B); printf(" -->\n");
printf("<!-- C: "); print(C); printf(" -->\n");
// SVG Header
printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
"preserveAspectRatio=\"xMidYMid meet\" "
"xmlns=\"http://www.w3.org/2000/svg\">\n",
page_width, page_height);
Tensor filled = make_tensor<bool>(make_shape(M, N, K));
clear(filled);
// --- Draw C ---
for (int tid = 0; tid < size<0>(C); ++tid) {
for (int vid = 0; vid < size<1>(C); ++vid) {
auto [m, n] = C(tid, vid);
if (!filled(m, n, 0)) {
filled(m, n, 0) = true;
int x = (n + K + 2) * cell_size;
int y = (m + K + 2) * cell_size;
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\"/>\n",
x, y, cell_size, cell_size, color(tid,vid));
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
x + cell_size/2, y + 1*cell_size/4, tid);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
x + cell_size/2, y + 3*cell_size/4, vid);
}
}
}
clear(filled);
// --- Draw A ---
for (int tid = 0; tid < size<0>(A); ++tid) {
for (int vid = 0; vid < size<1>(A); ++vid) {
auto [m, k] = A(tid, vid);
if (!filled(m, 0, k)) {
filled(m, 0, k) = true;
int x = (k + 1) * cell_size;
int y = (m + K + 2) * cell_size;
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
x, y, cell_size, cell_size, color(tid,vid));
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
x + cell_size/2, y + 1*cell_size/4, tid);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
x + cell_size/2, y + 3*cell_size/4, vid);
}
}
}
// A labels
for (int m = 0, k = -1; m < M; ++m) {
int x = (k + 1) * cell_size;
int y = (m + K + 2) * cell_size;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x + cell_size/2, y + cell_size/2, m);
}
for (int m = -1, k = 0; k < K; ++k) {
int x = (k + 1) * cell_size;
int y = (m + K + 2) * cell_size;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x + cell_size/2, y + cell_size/2, k);
}
clear(filled);
// --- Draw B ---
for (int tid = 0; tid < size<0>(B); ++tid) {
for (int vid = 0; vid < size<1>(B); ++vid) {
auto [n, k] = B(tid, vid);
if (!filled(0, n, k)) {
filled(0, n, k) = true;
int x = (n + K + 2) * cell_size;
int y = (k + 1) * cell_size;
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
x, y, cell_size, cell_size, color(tid,vid));
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
x + cell_size/2, y + 1*cell_size/4, tid);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
x + cell_size/2, y + 3*cell_size/4, vid);
}
}
}
// B labels
for (int n = 0, k = -1; n < N; ++n) {
int x = (n + K + 2) * cell_size;
int y = (k + 1) * cell_size;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x + cell_size/2, y + cell_size/2, n);
}
for (int n = -1, k = 0; k < K; ++k) {
int x = (n + K + 2) * cell_size;
int y = (k + 1) * cell_size;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x + cell_size/2, y + cell_size/2, k);
}
// SVG footer
printf("</svg>\n");
}
} // end namespace detail
// MMA Atom to SVG
template <class... Args, class SVGColorFn = SVGColor_TV>
CUTE_HOST_DEVICE
void
print_svg(MMA_Atom<Args...> const& mma_atom,
SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string
{
print_svg(make_tiled_mma(mma_atom));
}
// TiledMMA to SVG
template <class... Args, class SVGColorFn = SVGColor_TV>
CUTE_HOST_DEVICE
void
print_svg(TiledMMA<Args...> const& mma,
SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string
{
auto tile_mnk = tile_shape(mma);
Tensor refC = make_identity_tensor(select<0,1>(tile_mnk));
Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV());
Tensor refA = make_identity_tensor(select<0,2>(tile_mnk));
Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV());
Tensor refB = make_identity_tensor(select<1,2>(tile_mnk));
Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV());
detail::print_svg_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color);
}
} // end namespace cute

View File

@ -0,0 +1,188 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/layout.hpp>
#include <cute/tensor_impl.hpp>
namespace cute
{
////////////////////////////////
// Layout 2D to Console table //
////////////////////////////////
template <class Layout>
CUTE_HOST_DEVICE
void
print_layout(Layout const& layout) // (m,n) -> idx
{
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
int idx_width = num_digits(cosize(layout)) + 2;
const char* delim = "+-----------------------";
print(layout); print("\n");
// Column indices
print(" ");
for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); }
printf("\n");
// Print out A m-by-n
for (int m = 0; m < size<0>(layout); ++m) {
// Header
print(" ");
for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
printf("+\n");
// Values
printf("%2d ", m); // Row indices
for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); }
printf("|\n");
}
// Footer
print(" ");
for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
printf("+\n");
}
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts
template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
void
print_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
print_layout(as_position_independent_swizzle_layout(layout));
}
////////////////////////////////
// Tensor 1D,2D,3D,4D Console //
////////////////////////////////
template <class Engine, class Layout>
CUTE_HOST_DEVICE
void
print_tensor(Tensor<Engine,Layout> const& tensor, bool print_type = true)
{
if (print_type) {
print(tensor); print(":\n");
}
if constexpr (Layout::rank == 1)
{
for (int m = 0; m < size(tensor); ++m) {
pretty_print(tensor(m));
printf("\n");
}
} else
if constexpr (Layout::rank == 2)
{
for (int m = 0; m < size<0>(tensor); ++m) {
for (int n = 0; n < size<1>(tensor); ++n) {
pretty_print(tensor(m,n));
}
printf("\n");
}
} else
if constexpr (Layout::rank == 3)
{
print_tensor(tensor(_,_,0), false);
for (int k = 1; k < size<2>(tensor); ++k) {
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n");
print_tensor(tensor(_,_,k), false);
}
} else
if constexpr (Layout::rank == 4)
{
print_tensor(tensor(_,_,_,0), false);
for (int p = 1; p < size<3>(tensor); ++p) {
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n");
print_tensor(tensor(_,_,_,p), false);
}
}
}
#if !defined(__CUDACC_RTC__)
template <class Engine, class Layout>
CUTE_HOST
std::ostream&
print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
int digits = 9;
if constexpr (Layout::rank == 1)
{
for (int m = 0; m < size(tensor); ++m) {
os << std::setw(digits) << tensor(m) << std::endl;
}
} else
if constexpr (Layout::rank == 2)
{
for (int m = 0; m < size<0>(tensor); ++m) {
for (int n = 0; n < size<1>(tensor); ++n) {
os << std::setw(digits) << tensor(m,n);
}
os << std::endl;
}
} else
if constexpr (Layout::rank == 3)
{
print_tensor_os(os, tensor(_,_,0));
for (int k = 1; k < size<2>(tensor); ++k) {
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl;
print_tensor_os(os, tensor(_,_,k));
}
} else
if constexpr (Layout::rank == 4)
{
print_tensor_os(os, tensor(_,_,_,0));
for (int p = 1; p < size<3>(tensor); ++p) {
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl;
print_tensor_os(os, tensor(_,_,_,p));
}
}
return os;
}
template <class Engine, class Layout>
CUTE_HOST
std::ostream&
operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
os << tensor.layout() << std::endl;
return print_tensor_os(os, tensor);
}
#endif // !defined(__CUDACC_RTC__)
} // end namespace cute

View File

@ -92,6 +92,29 @@ using CUTE_STL_NAMESPACE::remove_const_t;
using CUTE_STL_NAMESPACE::remove_cv_t;
using CUTE_STL_NAMESPACE::remove_reference_t;
template <class Src, class Dst>
struct copy_cv {
using type = Dst;
};
template <class Src, class Dst>
struct copy_cv<Src const, Dst> {
using type = Dst const;
};
template <class Src, class Dst>
struct copy_cv<Src volatile, Dst> {
using type = Dst volatile;
};
template <class Src, class Dst>
struct copy_cv<Src const volatile, Dst> {
using type = Dst const volatile;
};
template <class Src, class Dst>
using copy_cv_t = typename copy_cv<Src,Dst>::type;
using CUTE_STL_NAMESPACE::extent;
using CUTE_STL_NAMESPACE::remove_extent;