v4.1 release
This commit is contained in:
@ -260,13 +260,10 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
|
||||
{
|
||||
// If more than one element vectorizes to 8bits or more, then recast and copy
|
||||
using VecType = uint_bit_t<vec_bits>;
|
||||
// Preserve volatility
|
||||
using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
|
||||
using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
|
||||
|
||||
// Recast
|
||||
Tensor src_v = recast<SrcVecType>(src);
|
||||
Tensor dst_v = recast<DstVecType>(dst);
|
||||
Tensor src_v = recast<VecType>(src);
|
||||
Tensor dst_v = recast<VecType>(dst);
|
||||
return copy_if(constant_fn<true_type>{}, src_v, dst_v);
|
||||
} else {
|
||||
return copy_if(constant_fn<true_type>{}, src, dst);
|
||||
|
||||
@ -325,21 +325,6 @@ struct TiledCopy : Copy_Atom
|
||||
return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{});
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
get_layoutS_MN()
|
||||
{
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
auto layoutS_TV = get_layoutS_TV();
|
||||
// (M,K) -> (thr_idx,val_idx)
|
||||
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{}));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));
|
||||
|
||||
return cute::make_tuple(layoutS_MK, thrID_S);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
get_layoutD_TV()
|
||||
@ -350,21 +335,6 @@ struct TiledCopy : Copy_Atom
|
||||
return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{});
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
get_layoutD_MN()
|
||||
{
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
auto layoutD_TV = get_layoutD_TV();
|
||||
// (M,K) -> (thr_idx,val_idx)
|
||||
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{}));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));
|
||||
|
||||
return cute::make_tuple(layoutD_MK, thrID_D);
|
||||
}
|
||||
|
||||
template <class ThrIdx,
|
||||
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
|
||||
CUTE_HOST_DEVICE static
|
||||
@ -680,101 +650,6 @@ print(ThrCopy<TiledCopy, ThrIdx> const& thr_copy)
|
||||
print(TiledCopy{});
|
||||
}
|
||||
|
||||
// TiledCopy to LaTeX TikZ
|
||||
template <class... Args, class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
print_latex(TiledCopy<Args...> const& copy,
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN();
|
||||
auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN();
|
||||
|
||||
print_latex_copy(layoutS_MN, thrID_S,
|
||||
layoutD_MN, thrID_D);
|
||||
}
|
||||
|
||||
// MNK Copy Layout to LaTeX TikZ
|
||||
template <class LayoutS, class ThrIDS,
|
||||
class LayoutD, class ThrIDD,
|
||||
class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutD const& D, ThrIDD const& TD, // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});
|
||||
|
||||
assert(size<0>(S) == size<0>(D));
|
||||
assert(size<1>(S) == size<1>(D));
|
||||
|
||||
// Commented prints
|
||||
printf("%% LayoutS: "); print(S); printf("\n");
|
||||
printf("%% ThrIDS : "); print(TS); printf("\n");
|
||||
printf("%% LayoutD: "); print(D); printf("\n");
|
||||
printf("%% ThrIDD : "); print(TD); printf("\n\n");
|
||||
|
||||
// Header
|
||||
printf("\\documentclass[convert]{standalone}\n"
|
||||
"\\usepackage{tikz}\n\n"
|
||||
"\\begin{document}\n"
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
|
||||
|
||||
// S starting at 0,0
|
||||
for (int i = 0; i < size<0>(S); ++i) {
|
||||
for (int j = 0; j < size<1>(S); ++j) {
|
||||
int thrid = S(i,j) % size(TS);
|
||||
int val_idx = S(i,j) / size(TS);
|
||||
int thr_idx = TS(thrid);
|
||||
|
||||
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color(thr_idx, val_idx),
|
||||
i, j,
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
// Grid
|
||||
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
|
||||
0, 0, int(size<0>(S)), int(size<1>(S)));
|
||||
// S Labels
|
||||
for (int i = 0, j = -1; i < size<0>(S); ++i) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
|
||||
}
|
||||
for (int i = -1, j = 0; j < size<1>(S); ++j) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
|
||||
}
|
||||
|
||||
// D starting at 0,size<1>(S)+3
|
||||
for (int i = 0; i < size<0>(D); ++i) {
|
||||
for (int j = 0; j < size<1>(D); ++j) {
|
||||
int thrid = D(i,j) % size(TD);
|
||||
int val_idx = D(i,j) / size(TD);
|
||||
int thr_idx = TD(thrid);
|
||||
|
||||
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color(thr_idx, val_idx),
|
||||
i, j + size<1>(S) + 3,
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
// Grid
|
||||
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
|
||||
0, int(size<1>(S)+3), int(size<0>(D)), int(size<1>(D)+size<1>(S)+3));
|
||||
// D Labels
|
||||
for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i);
|
||||
}
|
||||
for (int i = -1, j = 0; j < size<1>(D); ++j) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j);
|
||||
}
|
||||
|
||||
// Footer
|
||||
printf("\\end{tikzpicture}\n"
|
||||
"\\end{document}\n");
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -180,10 +180,10 @@ struct MMA_Atom<MMA_Traits<MMAOperation, Args...>>
|
||||
if constexpr (has_dereference<FrgTypeB>::value) {
|
||||
// If the intended FrgTypeB is a view (of the current tensor), forward the whole
|
||||
static_assert(is_same<ValTypeB, typename remove_cvref_t<BTensor>::value_type>::value
|
||||
|
||||
|
||||
|| (sizeof_bits_v<typename remove_cvref_t<BTensor>::value_type> == 8 &&
|
||||
(sizeof_bits_v<ValTypeB> == 8 || sizeof_bits_v<ValTypeB> == 6 || sizeof_bits_v<ValTypeB> == 4))
|
||||
|
||||
|
||||
, "Expecting ValTypeB type");
|
||||
return make_tensor<FrgTypeB>(static_cast<BTensor&&>(btensor));
|
||||
} else {
|
||||
@ -394,55 +394,22 @@ struct TiledMMA : MMA_Atom
|
||||
return size(permutation_mnk<I>());
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutC_MN() const
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
|
||||
// (cthrid,val) -> (M,N)
|
||||
auto layoutC_TV = thrfrg_C(ref_C);
|
||||
// (M,N) -> (cthrid,frg)
|
||||
auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C));
|
||||
|
||||
// cthrid = (v,m,n) -> thr_idx
|
||||
auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{});
|
||||
|
||||
return cute::make_tuple(layoutC_MN, thrID_C);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutC_TV() const
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>()));
|
||||
// (cthrid,val) -> (M,N)
|
||||
auto layoutC_TV = thrfrg_C(ref_C);
|
||||
|
||||
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
|
||||
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
|
||||
auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}),
|
||||
make_stride(Int<1>{}, Int<0>{})),
|
||||
right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_))));
|
||||
|
||||
// (thr_idx,val) -> (M,N)
|
||||
return layoutC_TV.compose(thridx_2_thrid, _);
|
||||
return thrfrg_C(ref_C).compose(thridx_2_thrid, _);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutA_MK() const
|
||||
{
|
||||
// (M,K) -> (M,K)
|
||||
auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>()));
|
||||
// (athrid,val) -> (M,K)
|
||||
auto layoutA_TV = thrfrg_A(ref_A);
|
||||
// (M,K) -> (athrid,frg)
|
||||
auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_);
|
||||
|
||||
return cute::make_tuple(layoutA_MK, thrID_A);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -458,29 +425,14 @@ struct TiledMMA : MMA_Atom
|
||||
_));
|
||||
|
||||
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
|
||||
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
|
||||
auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}),
|
||||
make_stride(Int<1>{}, Int<0>{})),
|
||||
right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_))));
|
||||
|
||||
// (thr_idx,val) -> (M,K)
|
||||
return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutB_NK() const
|
||||
{
|
||||
// (N,K) -> (N,K)
|
||||
auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>()));
|
||||
// (bthrid,val) -> (N,K)
|
||||
auto layoutB_TV = thrfrg_B(ref_B);
|
||||
// (N,K) -> (bthrid,frg)
|
||||
auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B));
|
||||
|
||||
// bthrid = (v,n,k) -> thr_idx
|
||||
auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_);
|
||||
|
||||
return cute::make_tuple(layoutB_NK, thrID_B);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_layoutB_TV() const
|
||||
@ -495,7 +447,9 @@ struct TiledMMA : MMA_Atom
|
||||
_));
|
||||
|
||||
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
|
||||
auto thridx_2_thrid = right_inverse(thr_layout_vmnk_);
|
||||
auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}),
|
||||
make_stride(Int<1>{}, Int<0>{})),
|
||||
right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_))));
|
||||
|
||||
// (thr_idx,val) -> (N,K)
|
||||
return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _);
|
||||
@ -733,376 +687,6 @@ print(ThrMMA<TiledMMA, ThrVMNK> const& thr_mma)
|
||||
print(static_cast<TiledMMA>(thr_mma));
|
||||
}
|
||||
|
||||
// MMA Atom to LaTeX TikZ
|
||||
template <class... Args, class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(MMA_Atom<Args...> const& mma_atom,
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
print_latex(make_tiled_mma(mma_atom));
|
||||
}
|
||||
|
||||
// TiledMMA to LaTeX TikZ
|
||||
template <class... Args, class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(TiledMMA<Args...> const& mma,
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
auto layout_and_thrid_C = mma.get_layoutC_MN();
|
||||
auto layoutC_MN = get<0>(layout_and_thrid_C);
|
||||
auto thrID_C = get<1>(layout_and_thrid_C);
|
||||
|
||||
auto layout_and_thrid_A = mma.get_layoutA_MK();
|
||||
auto layoutA_MK = get<0>(layout_and_thrid_A);
|
||||
auto thrID_A = get<1>(layout_and_thrid_A);
|
||||
|
||||
auto layout_and_thrid_B = mma.get_layoutB_NK();
|
||||
auto layoutB_NK = get<0>(layout_and_thrid_B);
|
||||
auto thrID_B = get<1>(layout_and_thrid_B);
|
||||
|
||||
print_latex_mma(layoutC_MN, thrID_C,
|
||||
layoutA_MK, thrID_A,
|
||||
layoutB_NK, thrID_B);
|
||||
}
|
||||
|
||||
// MNK MMA Layout to LaTeX TikZ
|
||||
template <class LayoutC, class ThrIDC,
|
||||
class LayoutA, class ThrIDA,
|
||||
class LayoutB, class ThrIDB,
|
||||
class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutB const& B, ThrIDB const& TB, // (n,k) -> (tid,vid) and tid -> thr_idx
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
|
||||
|
||||
assert(size<0>(A) == size<0>(C));
|
||||
assert(size<0>(B) == size<1>(C));
|
||||
assert(size<1>(A) == size<1>(B));
|
||||
|
||||
// Commented prints
|
||||
printf("%% LayoutC: "); print(C); printf("\n");
|
||||
printf("%% ThrIDC : "); print(TC); printf("\n");
|
||||
printf("%% LayoutA: "); print(A); printf("\n");
|
||||
printf("%% ThrIDA : "); print(TA); printf("\n");
|
||||
printf("%% LayoutB: "); print(B); printf("\n");
|
||||
printf("%% ThrIDB : "); print(TB); printf("\n\n");
|
||||
// Header
|
||||
printf("\\documentclass[convert]{standalone}\n"
|
||||
"\\usepackage{tikz}\n\n"
|
||||
"\\begin{document}\n"
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
|
||||
|
||||
// C starting at 0,0
|
||||
for (int m = 0; m < size<0>(C); ++m) {
|
||||
for (int n = 0; n < size<1>(C); ++n) {
|
||||
int thrid = C(m,n) % size(TC);
|
||||
int val_idx = C(m,n) / size(TC);
|
||||
int thr_idx = TC(thrid);
|
||||
|
||||
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color(thr_idx, val_idx),
|
||||
m, n,
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
// Grid
|
||||
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
|
||||
0, 0, int(size<0>(C)), int(size<1>(C)));
|
||||
|
||||
// A starting at 0,-size<1>(A)-1
|
||||
for (int m = 0; m < size<0>(A); ++m) {
|
||||
for (int k = 0; k < size<1>(A); ++k) {
|
||||
int thrid = A(m,k) % size(TA);
|
||||
int val_idx = A(m,k) / size(TA);
|
||||
int thr_idx = TA(thrid);
|
||||
|
||||
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color(thr_idx, val_idx),
|
||||
m, k-1-size<1>(A),
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
// Grid
|
||||
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
|
||||
0, int(-size<1>(A)-1), int(size<0>(A)), -1);
|
||||
// A labels
|
||||
for (int m = 0, k = -1; m < size<0>(A); ++m) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m);
|
||||
}
|
||||
for (int m = -1, k = 0; k < size<1>(A); ++k) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k);
|
||||
}
|
||||
|
||||
// B starting at -size<1>(B)-1,0
|
||||
for (int n = 0; n < size<0>(B); ++n) {
|
||||
for (int k = 0; k < size<1>(B); ++k) {
|
||||
int thrid = B(n,k) % size(TB);
|
||||
int val_idx = B(n,k) / size(TB);
|
||||
int thr_idx = TB(thrid);
|
||||
|
||||
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color(thr_idx, val_idx),
|
||||
k-1-size<1>(B), n,
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
// Grid
|
||||
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
|
||||
int(-size<1>(B)-1), 0, -1, int(size<0>(B)));
|
||||
// B labels
|
||||
for (int n = 0, k = -1; n < size<0>(B); ++n) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n);
|
||||
}
|
||||
for (int n = -1, k = 0; k < size<1>(B); ++k) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k);
|
||||
}
|
||||
|
||||
// Footer
|
||||
printf("\\end{tikzpicture}\n"
|
||||
"\\end{document}\n");
|
||||
}
|
||||
|
||||
// MNK MMA Layout to console printer
|
||||
template <class LayoutC, class ThrIDC,
|
||||
class LayoutA, class ThrIDA,
|
||||
class LayoutB, class ThrIDB>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});
|
||||
|
||||
assert(size<0>(A) == size<0>(C));
|
||||
assert(size<0>(B) == size<1>(C));
|
||||
assert(size<1>(A) == size<1>(B));
|
||||
|
||||
int a_width = size<1>(A) * 6 + 4;
|
||||
|
||||
// Print out B (white-shifted) k-by-n
|
||||
for (int k = 0; k < size<1>(B); ++k) {
|
||||
// Header
|
||||
printf("%*s", a_width, "");
|
||||
for (int n = 0; n < size<0>(B); ++n) printf("+-----");
|
||||
printf("+\n");
|
||||
// Values
|
||||
printf("%*s", a_width, "");
|
||||
for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB)));
|
||||
printf("|\n");
|
||||
}
|
||||
// Footer
|
||||
printf("%*s", a_width, "");
|
||||
for (int n = 0; n < size<0>(B); ++n) printf("+-----");
|
||||
printf("+\n\n");
|
||||
|
||||
// Print out A m-by-k and C m-by-n
|
||||
for (int m = 0; m < size<0>(A); ++m) {
|
||||
// Header
|
||||
for (int k = 0; k < size<1>(A); ++k) printf("+-----");
|
||||
printf("+ ");
|
||||
for (int n = 0; n < size<1>(C); ++n) printf("+-----");
|
||||
printf("+\n");
|
||||
// Values
|
||||
for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA)));
|
||||
printf("| ");
|
||||
for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC)));
|
||||
printf("|\n");
|
||||
}
|
||||
// Footer
|
||||
for (int k = 0; k < size<1>(A); ++k) printf("+-----");
|
||||
printf("+ ");
|
||||
for (int n = 0; n < size<1>(C); ++n) printf("+-----");
|
||||
printf("+\n");
|
||||
}
|
||||
|
||||
// MNK MMA Layout to SVG -- 8-value color coded by thread
|
||||
template <class LayoutC, class ThrIDC,
|
||||
class LayoutA, class ThrIDA,
|
||||
class LayoutB, class ThrIDB>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx
|
||||
{
|
||||
char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175",
|
||||
"255,175,175", "210,210,255", "210,255,210",
|
||||
"255,255,210", "255,210,210"};
|
||||
|
||||
const int cell_width = 20;
|
||||
const int cell_height = 20;
|
||||
|
||||
const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width;
|
||||
const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height;
|
||||
|
||||
// header
|
||||
printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
|
||||
"preserveAspectRatio=\"xMidYMid meet\" "
|
||||
"xmlns=\"http://www.w3.org/2000/svg\">\n",
|
||||
page_width, page_height);
|
||||
|
||||
// C
|
||||
int c_base_x = (size<1>(A) + 2) * cell_width;
|
||||
int c_base_y = (size<1>(B) + 2) * cell_height;
|
||||
for (int m = 0; m < cute::size<0>(C); ++m) {
|
||||
for (int n = 0; n < cute::size<1>(C); ++n) {
|
||||
|
||||
int thrid = C(m, n) % size(TC);
|
||||
int val_idx = C(m, n) / size(TC);
|
||||
int thr_idx = TC(thrid);
|
||||
|
||||
int x = n * cell_width + c_base_x;
|
||||
int y = m * cell_height + c_base_y;
|
||||
|
||||
int thr_x = x + cell_width / 2;
|
||||
int thr_y = y + cell_height / 4;
|
||||
int val_x = x + cell_width / 2;
|
||||
int val_y = y + cell_height * 3 / 4;
|
||||
|
||||
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
|
||||
"fill=\"rgb(%s)\" stroke=\"black\"/>\n",
|
||||
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
|
||||
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
|
||||
thr_x, thr_y, thr_idx);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
|
||||
val_x, val_y, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// A
|
||||
int a_base_x = cell_width;
|
||||
int a_base_y = (size<1>(B) + 2) * cell_height;
|
||||
for (int m = 0; m < size<0>(A); ++m) {
|
||||
for (int k = 0; k < size<1>(A); ++k) {
|
||||
int thrid = A(m, k) % size(TA);
|
||||
int val_idx = A(m, k) / size(TA);
|
||||
int thr_idx = TA(thrid);
|
||||
|
||||
int x = k * cell_width + a_base_x;
|
||||
int y = m * cell_height + a_base_y;
|
||||
|
||||
int thr_x = x + cell_width / 2;
|
||||
int thr_y = y + cell_height / 4;
|
||||
int val_x = x + cell_width / 2;
|
||||
int val_y = y + cell_height * 3 / 4;
|
||||
|
||||
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
|
||||
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
|
||||
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
|
||||
thr_x, thr_y, thr_idx);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
|
||||
val_x, val_y, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// B
|
||||
int b_base_x = (size<1>(A) + 2) * cell_width;
|
||||
int b_base_y = cell_height;
|
||||
for (int n = 0; n < size<0>(B); ++n) {
|
||||
for (int k = 0; k < size<1>(B); ++k) {
|
||||
int thrid = B(n, k) % size(TB);
|
||||
int val_idx = B(n, k) / size(TB);
|
||||
int thr_idx = TB(thrid);
|
||||
|
||||
int x = n * cell_width + b_base_x;
|
||||
int y = k * cell_height + b_base_y;
|
||||
|
||||
int thr_x = x + cell_width / 2;
|
||||
int thr_y = y + cell_height / 4;
|
||||
int val_x = x + cell_width / 2;
|
||||
int val_y = y + cell_height * 3 / 4;
|
||||
|
||||
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
|
||||
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
|
||||
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
|
||||
thr_x, thr_y, thr_idx);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
|
||||
val_x, val_y, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// A labels
|
||||
for (int m = 0; m < size<0>(A); ++m) {
|
||||
int x = cell_width / 2;
|
||||
int y = m * cell_height + cell_height / 2 + a_base_y;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, m);
|
||||
}
|
||||
for (int k = 0; k < size<1>(A); ++k) {
|
||||
int x = cell_width + k * cell_width + cell_width / 2;
|
||||
int y = -cell_height / 2 + a_base_y;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, k);
|
||||
}
|
||||
|
||||
// B labels
|
||||
for (int n = 0; n < size<0>(B); ++n) {
|
||||
int x = b_base_x + cell_width * n + cell_width / 2;
|
||||
int y = cell_height / 2;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, n);
|
||||
}
|
||||
for (int k = 0; k < size<1>(B); ++k) {
|
||||
int x = b_base_x - cell_width / 2;
|
||||
int y = cell_height * (k + 1) + cell_height / 2;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, k);
|
||||
}
|
||||
|
||||
// footer
|
||||
printf("</svg>\n");
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg(MMA_Atom<Args...> const &mma_atom) {
|
||||
print_svg(make_tiled_mma(mma_atom));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg(TiledMMA<Args...> const &mma) {
|
||||
auto layout_and_thrid_C = mma.get_layoutC_MN();
|
||||
auto layoutC_MN = get<0>(layout_and_thrid_C);
|
||||
auto thrID_C = get<1>(layout_and_thrid_C);
|
||||
|
||||
auto layout_and_thrid_A = mma.get_layoutA_MK();
|
||||
auto layoutA_MK = get<0>(layout_and_thrid_A);
|
||||
auto thrID_A = get<1>(layout_and_thrid_A);
|
||||
|
||||
auto layout_and_thrid_B = mma.get_layoutB_NK();
|
||||
auto layoutB_NK = get<0>(layout_and_thrid_B);
|
||||
auto thrID_B = get<1>(layout_and_thrid_B);
|
||||
|
||||
print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B);
|
||||
}
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1114,7 +698,7 @@ print_svg(TiledMMA<Args...> const &mma) {
|
||||
#include <cute/atom/mma_traits_sm89.hpp>
|
||||
#include <cute/atom/mma_traits_sm90.hpp>
|
||||
#include <cute/atom/mma_traits_sm90_gmma.hpp>
|
||||
#include <cute/atom/mma_traits_sm100.hpp>
|
||||
#include <cute/atom/mma_traits_sm100.hpp>
|
||||
#include <cute/atom/mma_traits_sm120.hpp>
|
||||
#include <cute/atom/mma_traits_sm120_sparse.hpp>
|
||||
|
||||
|
||||
@ -3844,4 +3844,39 @@ struct MMA_Traits<SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE<a_type, b_type, c_type, sf_
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Specialization for a vectorized FMA per thread.
|
||||
*/
|
||||
template <>
|
||||
struct MMA_Traits<SM100_2x1x1_F32F32F32F32>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = float;
|
||||
using ValTypeB = float;
|
||||
using ValTypeC = float;
|
||||
|
||||
using Shape_MNK = Shape<_2,_1,_1>;
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
using ALayout = Layout<Shape<_1,_2>>;
|
||||
using BLayout = Layout<Shape<_1,_1>>;
|
||||
using CLayout = Layout<Shape<_1,_2>>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM100_1x2x1_F32F32F32F32>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = float;
|
||||
using ValTypeB = float;
|
||||
using ValTypeC = float;
|
||||
|
||||
using Shape_MNK = Shape<_1,_2,_1>;
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
using ALayout = Layout<Shape<_1,_1>>;
|
||||
using BLayout = Layout<Shape<_1,_2>>;
|
||||
using CLayout = Layout<Shape<_1,_2>>;
|
||||
};
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -834,7 +834,7 @@ coalesce_x(Layout<Shape,Stride> const& layout)
|
||||
} else {
|
||||
return detail::bw_coalesce<R-2>(flat_shape, flat_stride, get<R-1>(flat_shape), get<R-1>(flat_stride));
|
||||
}
|
||||
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
@ -1944,185 +1944,4 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout<Shape,Stride> const&
|
||||
}
|
||||
#endif
|
||||
|
||||
// Generic 2D Layout to console table
|
||||
template <class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_layout(Layout const& layout) // (m,n) -> idx
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
|
||||
|
||||
int idx_width = num_digits(cosize(layout)) + 2;
|
||||
const char* delim = "+-----------------------";
|
||||
|
||||
print(layout); print("\n");
|
||||
|
||||
// Column indices
|
||||
print(" ");
|
||||
for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); }
|
||||
printf("\n");
|
||||
|
||||
// Print out A m-by-n
|
||||
for (int m = 0; m < size<0>(layout); ++m) {
|
||||
// Header
|
||||
print(" ");
|
||||
for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
|
||||
printf("+\n");
|
||||
// Values
|
||||
printf("%2d ", m); // Row indices
|
||||
for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); }
|
||||
printf("|\n");
|
||||
}
|
||||
// Footer
|
||||
print(" ");
|
||||
for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
|
||||
printf("+\n");
|
||||
}
|
||||
|
||||
// Generic ThrVal 2D Layout to console table
|
||||
template <class Layout, class ThrID>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
|
||||
|
||||
print(layout); print("\n");
|
||||
print(thrid); print("\n");
|
||||
|
||||
// Print out m-by-n
|
||||
for (int m = 0; m < size<0>(layout); ++m) {
|
||||
// Header
|
||||
for (int n = 0; n < size<1>(layout); ++n) printf("+------");
|
||||
printf("+\n");
|
||||
// Values
|
||||
for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid)));
|
||||
printf("|\n");
|
||||
}
|
||||
// Footer
|
||||
for (int n = 0; n < size<1>(layout); ++n) printf("+------");
|
||||
printf("+\n");
|
||||
}
|
||||
|
||||
struct TikzColor_White {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int idx) const {
|
||||
return "white";
|
||||
}
|
||||
};
|
||||
|
||||
struct TikzColor_BWx8 {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int idx) const {
|
||||
static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60",
|
||||
"black!10", "black!50", "black!30", "black!70"};
|
||||
return color_map[idx % 8];
|
||||
}
|
||||
};
|
||||
|
||||
struct TikzColor_TV {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int tid, int vid) const {
|
||||
static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
|
||||
"{rgb,255:red,175;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,175;blue,175}",
|
||||
"{rgb,255:red,210;green,210;blue,255}",
|
||||
"{rgb,255:red,210;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,210;blue,210}"};
|
||||
return color_map[tid % 8];
|
||||
}
|
||||
};
|
||||
|
||||
// Generic 2D Layout to LaTeX printer
|
||||
template <class LayoutA, class TikzColorFn = TikzColor_BWx8>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(LayoutA const& layout_a, // (m,n) -> idx
|
||||
TikzColorFn color = {}) // lambda(idx) -> tikz color string
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{});
|
||||
auto layout = append<2>(layout_a, Layout<_1,_0>{});
|
||||
|
||||
// Commented print(layout)
|
||||
printf("%% Layout: "); print(layout); printf("\n");
|
||||
// Header
|
||||
printf("\\documentclass[convert]{standalone}\n"
|
||||
"\\usepackage{tikz}\n\n"
|
||||
"\\begin{document}\n"
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
|
||||
|
||||
// Layout
|
||||
for (int i = 0; i < size<0>(layout); ++i) {
|
||||
for (int j = 0; j < size<1>(layout); ++j) {
|
||||
int idx = layout(i,j);
|
||||
printf("\\node[fill=%s] at (%d,%d) {%d};\n",
|
||||
color(idx), i, j, idx);
|
||||
}
|
||||
}
|
||||
// Grid
|
||||
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
|
||||
int(size<0>(layout)), int(size<1>(layout)));
|
||||
// Labels
|
||||
for (int i = 0, j = -1; i < size<0>(layout); ++i) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
|
||||
}
|
||||
for (int i = -1, j = 0; j < size<1>(layout); ++j) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
|
||||
}
|
||||
|
||||
// Footer
|
||||
printf("\\end{tikzpicture}\n"
|
||||
"\\end{document}\n");
|
||||
}
|
||||
|
||||
// Generic ThrVal 2D Layout to LaTeX TikZ
|
||||
template <class Layout, class ThrID, class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(Layout const& layout, // (m,n) -> (tid,vid)
|
||||
ThrID const& thr, // tid -> thr_idx
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
|
||||
|
||||
// Commented prints
|
||||
printf("%% Layout: "); print(layout); printf("\n");
|
||||
printf("%% ThrID : "); print(thr); printf("\n");
|
||||
// Header
|
||||
printf("\\documentclass[convert]{standalone}\n"
|
||||
"\\usepackage{tikz}\n\n"
|
||||
"\\begin{document}\n"
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");
|
||||
|
||||
// Layout
|
||||
for (int i = 0; i < size<0>(layout); ++i) {
|
||||
for (int j = 0; j < size<1>(layout); ++j) {
|
||||
int thrid = layout(i,j) % size(thr);
|
||||
int val_idx = layout(i,j) / size(thr);
|
||||
int thr_idx = thr(thrid);
|
||||
|
||||
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color(thr_idx, val_idx),
|
||||
i, j,
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
// Grid
|
||||
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
|
||||
int(size<0>(layout)), int(size<1>(layout)));
|
||||
// Labels
|
||||
for (int i = 0, j = -1; i < size<0>(layout); ++i) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
|
||||
}
|
||||
for (int j = 0, i = -1; j < size<1>(layout); ++j) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
|
||||
}
|
||||
|
||||
// Footer
|
||||
printf("\\end{tikzpicture}\n"
|
||||
"\\end{document}\n");
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -42,22 +42,15 @@ namespace cute
|
||||
{
|
||||
|
||||
template <class... T>
|
||||
struct ArithmeticTuple : tuple<T...>
|
||||
{
|
||||
template <class... U>
|
||||
struct ArithmeticTuple : public tuple<T...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple(ArithmeticTuple<U...> const& u)
|
||||
: tuple<T...>(static_cast<tuple<U...> const&>(u)) {}
|
||||
ArithmeticTuple() : tuple<T...>() {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple(tuple<U...> const& u)
|
||||
: tuple<T...>(u) {}
|
||||
ArithmeticTuple(tuple<T...> const& t) : tuple<T...>(t) {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple(U const&... u)
|
||||
: tuple<T...>(u...) {}
|
||||
ArithmeticTuple(T const&... t) : tuple<T...>(t...) {}
|
||||
};
|
||||
|
||||
template <class... T>
|
||||
@ -147,12 +140,12 @@ operator-(ArithmeticTuple<T...> const& t) {
|
||||
}
|
||||
|
||||
//
|
||||
// Special cases
|
||||
// Special cases for C<0>
|
||||
//
|
||||
|
||||
template <auto t, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple<U...> const&
|
||||
ArithmeticTuple<U...>
|
||||
operator+(C<t>, ArithmeticTuple<U...> const& u) {
|
||||
static_assert(t == 0, "Arithmetic tuple op+ error!");
|
||||
return u;
|
||||
@ -160,7 +153,7 @@ operator+(C<t>, ArithmeticTuple<U...> const& u) {
|
||||
|
||||
template <class... T, auto u>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple<T...> const&
|
||||
ArithmeticTuple<T...>
|
||||
operator+(ArithmeticTuple<T...> const& t, C<u>) {
|
||||
static_assert(u == 0, "Arithmetic tuple op+ error!");
|
||||
return t;
|
||||
@ -168,7 +161,7 @@ operator+(ArithmeticTuple<T...> const& t, C<u>) {
|
||||
|
||||
template <auto t, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple<U...> const&
|
||||
ArithmeticTuple<U...>
|
||||
operator-(C<t>, ArithmeticTuple<U...> const& u) {
|
||||
static_assert(t == 0, "Arithmetic tuple op- error!");
|
||||
return -u;
|
||||
@ -176,7 +169,7 @@ operator-(C<t>, ArithmeticTuple<U...> const& u) {
|
||||
|
||||
template <class... T, auto u>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple<T...> const&
|
||||
ArithmeticTuple<T...>
|
||||
operator-(ArithmeticTuple<T...> const& t, C<u>) {
|
||||
static_assert(u == 0, "Arithmetic tuple op- error!");
|
||||
return t;
|
||||
@ -212,27 +205,20 @@ struct ArithmeticTupleIterator
|
||||
}
|
||||
};
|
||||
|
||||
template <class Tuple>
|
||||
template <class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_inttuple_iter(Tuple const& t) {
|
||||
return ArithmeticTupleIterator(as_arithmetic_tuple(t));
|
||||
}
|
||||
|
||||
template <class T0, class T1, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) {
|
||||
return make_inttuple_iter(cute::make_tuple(t0, t1, ts...));
|
||||
make_inttuple_iter(Ts const&... ts) {
|
||||
return ArithmeticTupleIterator(as_arithmetic_tuple(ts...));
|
||||
}
|
||||
|
||||
//
|
||||
// ArithmeticTuple "basis" elements
|
||||
// A ScaledBasis<T,N> is a (at least) rank-N+1 ArithmeticTuple:
|
||||
// A ScaledBasis<T,Ns...> is a (at least) rank-N+1 ArithmeticTuple:
|
||||
// (_0,_0,...,T,_0,...)
|
||||
// with value T in the Nth mode
|
||||
|
||||
template <class T, int N>
|
||||
template <class T, int... Ns>
|
||||
struct ScaledBasis : private tuple<T>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -243,40 +229,61 @@ struct ScaledBasis : private tuple<T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) value() const { return get<0>(static_cast<tuple<T> const&>(*this)); }
|
||||
|
||||
// Deprecated: Get the first hierarchical mode in this basis.
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
auto mode() { return Int<N>{}; }
|
||||
auto mode() { return get<0>(int_sequence<Ns...>{}); }
|
||||
};
|
||||
|
||||
// Ensure flat representation
|
||||
template <class T, int... Ms, int... Ns>
|
||||
struct ScaledBasis<ScaledBasis<T, Ms...>, Ns...> : ScaledBasis<T, Ns..., Ms...> {};
|
||||
|
||||
template <class T>
|
||||
struct is_scaled_basis : false_type {};
|
||||
template <class T, int N>
|
||||
struct is_scaled_basis<ScaledBasis<T,N>> : true_type {};
|
||||
template <class T, int... Ns>
|
||||
struct is_scaled_basis<ScaledBasis<T,Ns...>> : true_type {};
|
||||
|
||||
template <class T, int N>
|
||||
struct is_integral<ScaledBasis<T,N>> : true_type {};
|
||||
template <class T, int... Ns>
|
||||
struct is_integral<ScaledBasis<T,Ns...>> : true_type {};
|
||||
|
||||
// Get the scalar T out of a ScaledBasis
|
||||
template <class SB>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
basis_value(SB const& e)
|
||||
// Shortcuts
|
||||
// E<> := _1
|
||||
// E<0> := (_1,_0,_0,...)
|
||||
// E<1> := (_0,_1,_0,...)
|
||||
// E<0,0> := ((_1,_0,_0,...),_0,_0,...)
|
||||
// E<0,1> := ((_0,_1,_0,...),_0,_0,...)
|
||||
// E<1,0> := (_0,(_1,_0,_0,...),_0,...)
|
||||
// E<1,1> := (_0,(_0,_1,_0,...),_0,...)
|
||||
template <int... Ns>
|
||||
using E = ScaledBasis<Int<1>,Ns...>;
|
||||
|
||||
// Apply the Ns... pack to another Tuple
|
||||
template <class T, class Tuple>
|
||||
CUTE_HOST_DEVICE decltype(auto)
|
||||
basis_get(T const&, Tuple&& t)
|
||||
{
|
||||
if constexpr (is_scaled_basis<SB>::value) {
|
||||
return basis_value(e.value());
|
||||
return static_cast<Tuple&&>(t);
|
||||
}
|
||||
|
||||
template <class T, int... Ns, class Tuple>
|
||||
CUTE_HOST_DEVICE decltype(auto)
|
||||
basis_get(ScaledBasis<T,Ns...> const&, Tuple&& t)
|
||||
{
|
||||
if constexpr (sizeof...(Ns) == 0) {
|
||||
return static_cast<Tuple&&>(t);
|
||||
} else {
|
||||
return e;
|
||||
return get<Ns...>(static_cast<Tuple&&>(t));
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Apply the N... pack to another Tuple
|
||||
template <class SB, class Tuple>
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE decltype(auto)
|
||||
basis_get(SB const& e, Tuple&& t)
|
||||
{
|
||||
if constexpr (is_scaled_basis<SB>::value) {
|
||||
return basis_get(e.value(), get<SB::mode()>(static_cast<Tuple&&>(t)));
|
||||
basis_value(T const& e) {
|
||||
if constexpr (is_scaled_basis<T>::value) {
|
||||
return e.value();
|
||||
} else {
|
||||
return static_cast<Tuple&&>(t);
|
||||
return e;
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
@ -294,65 +301,34 @@ to_atuple_i(T const& t, seq<I...>) {
|
||||
|
||||
// Turn a ScaledBases<T,N> into a rank-N+1 ArithmeticTuple
|
||||
// with N prefix 0s: (_0,_0,...N...,_0,T)
|
||||
template <class T, int N>
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
||||
return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq<N>{});
|
||||
as_arithmetic_tuple(ScaledBasis<T> const& t) {
|
||||
return t.value();
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <class T, int N, int... Ns>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ScaledBasis<T,N,Ns...> const& t) {
|
||||
return detail::to_atuple_i(as_arithmetic_tuple(ScaledBasis<T,Ns...>{t.value()}), make_seq<N>{});
|
||||
}
|
||||
|
||||
template <int... Ns>
|
||||
struct Basis;
|
||||
|
||||
template <>
|
||||
struct Basis<> {
|
||||
using type = Int<1>;
|
||||
};
|
||||
|
||||
template <int N, int... Ns>
|
||||
struct Basis<N,Ns...> {
|
||||
using type = ScaledBasis<typename Basis<Ns...>::type, N>;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Shortcut for writing ScaledBasis<ScaledBasis<ScaledBasis<Int<1>, N0>, N1>, ...>
|
||||
// E<> := _1
|
||||
// E<0> := (_1,_0,_0,...)
|
||||
// E<1> := (_0,_1,_0,...)
|
||||
// E<0,0> := ((_1,_0,_0,...),_0,_0,...)
|
||||
// E<0,1> := ((_0,_1,_0,...),_0,_0,...)
|
||||
// E<1,0> := (_0,(_1,_0,_0,...),_0,...)
|
||||
// E<1,1> := (_0,(_0,_1,_0,...),_0,...)
|
||||
template <int... N>
|
||||
using E = typename detail::Basis<N...>::type;
|
||||
|
||||
template <class Shape>
|
||||
template <int... Ns, class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_basis_like(Shape const& shape)
|
||||
{
|
||||
if constexpr (is_integral<Shape>::value) {
|
||||
return Int<1>{};
|
||||
} else {
|
||||
// Generate bases for each rank of shape
|
||||
if constexpr (is_tuple<Shape>::value) {
|
||||
// Generate bases for each mode of shape
|
||||
return transform(tuple_seq<Shape>{}, shape, [](auto I, auto si) {
|
||||
// Generate bases for each rank of si and add an i on front
|
||||
using I_type = decltype(I);
|
||||
return transform_leaf(make_basis_like(si), [](auto e) {
|
||||
// MSVC has trouble capturing variables as constexpr,
|
||||
// so that they can be used as template arguments.
|
||||
// This is exactly what the code needs to do with i, unfortunately.
|
||||
// The work-around is to define i inside the inner lambda,
|
||||
// by using just the type from the enclosing scope.
|
||||
constexpr int i = I_type::value;
|
||||
return ScaledBasis<decltype(e), i>{};
|
||||
});
|
||||
// Generate bases for each si and add an i on end
|
||||
return make_basis_like<Ns...,decltype(I)::value>(si);
|
||||
});
|
||||
} else {
|
||||
return E<Ns...>{};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
@ -360,109 +336,124 @@ make_basis_like(Shape const& shape)
|
||||
// Arithmetic
|
||||
//
|
||||
|
||||
template <class T, int M, class U>
|
||||
template <class T, int... Ns, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
safe_div(ScaledBasis<T,Ns...> const& b, U const& u)
|
||||
{
|
||||
auto t = safe_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
return ScaledBasis<decltype(t),Ns...>{t};
|
||||
}
|
||||
|
||||
template <class T, int M, class U>
|
||||
template <class T, int... Ns, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
ceil_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
ceil_div(ScaledBasis<T,Ns...> const& b, U const& u)
|
||||
{
|
||||
auto t = ceil_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
return ScaledBasis<decltype(t),Ns...>{t};
|
||||
}
|
||||
|
||||
template <class T, int N>
|
||||
template <class T, int... Ns>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
abs(ScaledBasis<T,N> const& e)
|
||||
abs(ScaledBasis<T,Ns...> const& e)
|
||||
{
|
||||
auto t = abs(e.value());
|
||||
return ScaledBasis<decltype(t),N>{t};
|
||||
return ScaledBasis<decltype(t),Ns...>{t};
|
||||
}
|
||||
|
||||
// Equality
|
||||
template <class T, int N, class U, int M>
|
||||
template <class T, int... Ns, class U, int... Ms>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
||||
return bool_constant<M == N>{} && t.value() == u.value();
|
||||
operator==(ScaledBasis<T,Ns...> const& t, ScaledBasis<U,Ms...> const& u) {
|
||||
if constexpr (sizeof...(Ns) == sizeof...(Ms)) {
|
||||
return bool_constant<((Ns == Ms) && ...)>{} && t.value() == u.value();
|
||||
} else {
|
||||
return false_type{};
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Not equal to anything else
|
||||
template <class T, int N, class U>
|
||||
template <class T, int... Ns, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
false_type
|
||||
operator==(ScaledBasis<T,N> const&, U const&) {
|
||||
operator==(ScaledBasis<T,Ns...> const&, U const&) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U, int M>
|
||||
template <class T, class U, int... Ms>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
false_type
|
||||
operator==(T const&, ScaledBasis<U,M> const&) {
|
||||
operator==(T const&, ScaledBasis<U,Ms...> const&) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Multiplication
|
||||
template <class A, class T, int N>
|
||||
template <class A, class T, int... Ns>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(A const& a, ScaledBasis<T,N> const& e) {
|
||||
operator*(A const& a, ScaledBasis<T,Ns...> const& e) {
|
||||
auto r = a * e.value();
|
||||
return ScaledBasis<decltype(r),N>{r};
|
||||
return ScaledBasis<decltype(r),Ns...>{r};
|
||||
}
|
||||
|
||||
template <class T, int N, class B>
|
||||
template <class T, int... Ns, class B>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(ScaledBasis<T,N> const& e, B const& b) {
|
||||
operator*(ScaledBasis<T,Ns...> const& e, B const& b) {
|
||||
auto r = e.value() * b;
|
||||
return ScaledBasis<decltype(r),N>{r};
|
||||
return ScaledBasis<decltype(r),Ns...>{r};
|
||||
}
|
||||
|
||||
// Addition
|
||||
template <class T, int N, class U, int M>
|
||||
template <class T, int... Ns, class U, int... Ms>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
||||
operator+(ScaledBasis<T,Ns...> const& t, ScaledBasis<U,Ms...> const& u) {
|
||||
return as_arithmetic_tuple(t) + as_arithmetic_tuple(u);
|
||||
}
|
||||
|
||||
template <class T, int N, class... U>
|
||||
template <class T, int... Ns, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
|
||||
operator+(ScaledBasis<T,Ns...> const& t, ArithmeticTuple<U...> const& u) {
|
||||
return as_arithmetic_tuple(t) + u;
|
||||
}
|
||||
|
||||
template <class... T, class U, int M>
|
||||
template <class... T, class U, int... Ms>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
||||
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,Ms...> const& u) {
|
||||
return t + as_arithmetic_tuple(u);
|
||||
}
|
||||
|
||||
template <auto t, class U, int M>
|
||||
template <auto t, class U, int... Ms>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(C<t>, ScaledBasis<U,M> const& u) {
|
||||
static_assert(t == 0, "ScaledBasis op+ error!");
|
||||
return u;
|
||||
operator+(C<t>, ScaledBasis<U,Ms...> const& u) {
|
||||
if constexpr (sizeof...(Ms) == 0) {
|
||||
return C<t>{} + u.value();
|
||||
} else {
|
||||
static_assert(t == 0, "ScaledBasis op+ error!");
|
||||
return u;
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, int N, auto u>
|
||||
template <class T, int... Ns, auto u>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, C<u>) {
|
||||
static_assert(u == 0, "ScaledBasis op+ error!");
|
||||
return t;
|
||||
operator+(ScaledBasis<T,Ns...> const& t, C<u>) {
|
||||
if constexpr (sizeof...(Ns) == 0) {
|
||||
return t.value() + C<u>{};
|
||||
} else {
|
||||
static_assert(u == 0, "ScaledBasis op+ error!");
|
||||
return t;
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
@ -475,10 +466,10 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter)
|
||||
printf("ArithTuple"); print(iter.coord_);
|
||||
}
|
||||
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e)
|
||||
template <class T, int... Ns>
|
||||
CUTE_HOST_DEVICE void print(ScaledBasis<T,Ns...> const& e)
|
||||
{
|
||||
print(e.value()); printf("@%d", N);
|
||||
print(e.value()); (void(printf("@%d", Ns)), ...);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
@ -488,10 +479,11 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator<Ari
|
||||
return os << "ArithTuple" << iter.coord_;
|
||||
}
|
||||
|
||||
template <class T, int N>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e)
|
||||
template <class T, int... Ns>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,Ns...> const& e)
|
||||
{
|
||||
return os << e.value() << "@" << N;
|
||||
os << e.value(); (void(os << "@" << Ns), ...);
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@ -47,8 +47,9 @@ namespace cute
|
||||
// Signed integers
|
||||
//
|
||||
|
||||
using int2_t = cutlass::int2b_t;
|
||||
using int4_t = cutlass::int4b_t;
|
||||
using int2_t = cutlass::int2b_t;
|
||||
using int4_t = cutlass::int4b_t;
|
||||
using int6_t = cutlass::int6b_t;
|
||||
using CUTE_STL_NAMESPACE::int8_t;
|
||||
using CUTE_STL_NAMESPACE::int16_t;
|
||||
using CUTE_STL_NAMESPACE::int32_t;
|
||||
@ -75,10 +76,10 @@ using int_byte_t = typename int_byte<N>::type;
|
||||
// Unsigned integers
|
||||
//
|
||||
|
||||
using uint1_t = cutlass::uint1b_t;
|
||||
using uint2_t = cutlass::uint2b_t;
|
||||
using uint4_t = cutlass::uint4b_t;
|
||||
using uint6_t = cutlass::uint6b_t;
|
||||
using uint1_t = cutlass::uint1b_t;
|
||||
using uint2_t = cutlass::uint2b_t;
|
||||
using uint4_t = cutlass::uint4b_t;
|
||||
using uint6_t = cutlass::uint6b_t;
|
||||
using CUTE_STL_NAMESPACE::uint8_t;
|
||||
using CUTE_STL_NAMESPACE::uint16_t;
|
||||
using CUTE_STL_NAMESPACE::uint32_t;
|
||||
@ -88,7 +89,7 @@ template <int N> struct uint_bit;
|
||||
template <> struct uint_bit< 1> { using type = uint1_t; };
|
||||
template <> struct uint_bit< 2> { using type = uint2_t; };
|
||||
template <> struct uint_bit< 4> { using type = uint4_t; };
|
||||
template <> struct uint_bit< 6> { using type = uint6_t; };
|
||||
template <> struct uint_bit< 6> { using type = uint6_t; };
|
||||
template <> struct uint_bit< 8> { using type = uint8_t; };
|
||||
template <> struct uint_bit< 16> { using type = uint16_t; };
|
||||
template <> struct uint_bit< 32> { using type = uint32_t; };
|
||||
|
||||
@ -38,10 +38,19 @@
|
||||
|
||||
namespace cute {
|
||||
|
||||
template <typename T>
|
||||
struct sizeof_bits : public cutlass::sizeof_bits<T> {};
|
||||
template <class T>
|
||||
struct sizeof_bits : cutlass::sizeof_bits<T> {};
|
||||
|
||||
// DO NOT change auto to int, sizeof_bits<sparse_elem> use integral_ratio instead of int
|
||||
template <class T>
|
||||
struct sizeof_bits<T const> : sizeof_bits<T> {};
|
||||
|
||||
template <class T>
|
||||
struct sizeof_bits<T volatile> : sizeof_bits<T> {};
|
||||
|
||||
template <class T>
|
||||
struct sizeof_bits<T const volatile> : sizeof_bits<T> {};
|
||||
|
||||
// DO NOT change auto to int, sizeof_bits<sparse_elem> use integral_ratio instead of int
|
||||
template <class T>
|
||||
static constexpr auto sizeof_bits_v = sizeof_bits<T>::value;
|
||||
|
||||
@ -53,6 +62,23 @@ using cutlass::is_subbyte;
|
||||
template <class T>
|
||||
static constexpr auto is_subbyte_v = is_subbyte<T>::value;
|
||||
|
||||
//
|
||||
// Integral
|
||||
//
|
||||
|
||||
using cutlass::bin1_t;
|
||||
using cutlass::uint1b_t;
|
||||
using cutlass::int2b_t;
|
||||
using cutlass::uint2b_t;
|
||||
using cutlass::int4b_t;
|
||||
using cutlass::uint4b_t;
|
||||
using cutlass::int6b_t;
|
||||
using cutlass::uint6b_t;
|
||||
|
||||
//
|
||||
// Floating Point
|
||||
//
|
||||
|
||||
using cutlass::half_t;
|
||||
using cutlass::bfloat16_t;
|
||||
|
||||
@ -65,18 +91,12 @@ using cutlass::type_erased_dynamic_float8_t;
|
||||
using cutlass::float_e4m3_t;
|
||||
using cutlass::float_e5m2_t;
|
||||
|
||||
using cutlass::uint1b_t;
|
||||
using cutlass::int2b_t;
|
||||
using cutlass::uint2b_t;
|
||||
using cutlass::int4b_t;
|
||||
using cutlass::uint4b_t;
|
||||
using cutlass::bin1_t;
|
||||
|
||||
|
||||
|
||||
using cutlass::float_ue4m3_t;
|
||||
using cutlass::float_ue8m0_t;
|
||||
|
||||
using cutlass::uint6b_t;
|
||||
using cutlass::float_e2m1_t;
|
||||
using cutlass::float_e2m3_t;
|
||||
using cutlass::float_e3m2_t;
|
||||
@ -94,8 +114,6 @@ using cutlass::detail::type_erased_dynamic_float4_unpacksmem_t;
|
||||
using cutlass::detail::type_erased_dynamic_float6_unpacksmem_t;
|
||||
};
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Print utility
|
||||
//
|
||||
@ -112,7 +130,6 @@ print(bfloat16_t a) {
|
||||
printf("%f", static_cast<float>(a));
|
||||
}
|
||||
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(tfloat32_t a) {
|
||||
@ -131,6 +148,15 @@ print(float_e5m2_t a) {
|
||||
printf("%f", static_cast<float>(a));
|
||||
}
|
||||
|
||||
template <cutlass::detail::FpEncoding Encoding, class Derived>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(cutlass::float_exmy_base<Encoding, Derived> a) {
|
||||
printf("%f", static_cast<float>(a));
|
||||
}
|
||||
|
||||
// Pretty Print utility
|
||||
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(bfloat16_t v) {
|
||||
printf("%*.2f", 8, float(v));
|
||||
@ -156,26 +182,11 @@ pretty_print(float_e5m2_t t) {
|
||||
printf("%*.2f", 8, static_cast<float>(t));
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
cutlass::detail::FpEncoding Encoding,
|
||||
class Derived
|
||||
>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(cutlass::float_exmy_base<Encoding, Derived> a) {
|
||||
printf("%f", static_cast<float>(a));
|
||||
}
|
||||
|
||||
template <
|
||||
cutlass::detail::FpEncoding Encoding,
|
||||
class Derived
|
||||
>
|
||||
template <cutlass::detail::FpEncoding Encoding, class Derived>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
pretty_print_float_exmy_base(cutlass::float_exmy_base<Encoding, Derived> t) {
|
||||
printf("%*.2f", 8, static_cast<float>(t));
|
||||
}
|
||||
|
||||
|
||||
} // namespace cute
|
||||
|
||||
@ -33,9 +33,9 @@
|
||||
#include <cute/config.hpp> // CUTE_HOST_DEVICE
|
||||
#include <cute/pointer_base.hpp> // cute::iter_adaptor
|
||||
#include <cute/pointer_sparse.hpp>
|
||||
#include <cute/container/array_subbyte.hpp> // cute::subbyte_iterator
|
||||
#include <cute/numeric/integral_constant.hpp> // cute::true_type, cute::false_type
|
||||
#include <cute/numeric/numeric_types.hpp> // sizeof_bits
|
||||
#include <cute/container/array_subbyte.hpp> // cute::subbyte_iterator
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -51,11 +51,13 @@ namespace cute
|
||||
// Requires construction of a sparse_ptr that emulates access to the S logical elements.
|
||||
//
|
||||
|
||||
template <class NewT>
|
||||
template <class NewT_, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast_ptr(void* ptr)
|
||||
recast_ptr(T* ptr)
|
||||
{
|
||||
using NewT = copy_cv_t<T, NewT_>;
|
||||
|
||||
if constexpr (is_sparse<NewT>::value) {
|
||||
constexpr int sparsity = NewT::sparsity;
|
||||
NewT* p = reinterpret_cast<NewT*>(ptr);
|
||||
@ -69,24 +71,6 @@ recast_ptr(void* ptr)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class NewT>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast_ptr(void const* ptr)
|
||||
{
|
||||
if constexpr (is_sparse<NewT>::value) {
|
||||
constexpr int sparsity = NewT::sparsity;
|
||||
NewT const* p = reinterpret_cast<NewT const*>(ptr);
|
||||
return make_sparse_ptr<sparsity>(p);
|
||||
} else
|
||||
if constexpr (cute::is_subbyte_v<NewT>) {
|
||||
return subbyte_iterator<NewT const>(ptr);
|
||||
} else {
|
||||
return reinterpret_cast<NewT const*>(ptr);
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Disambiguate nullptr
|
||||
template <class NewT>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
|
||||
@ -167,23 +167,6 @@ downcast(ComposedLayout<SwizzleFn,smem_sparse_ptr_flag_bits<S,B>,Layout> const&
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts
|
||||
template <class SwizzleFn, int B, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
|
||||
{
|
||||
print_layout(as_position_independent_swizzle_layout(layout));
|
||||
}
|
||||
|
||||
template <class SwizzleFn, int B, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
|
||||
{
|
||||
print_latex(as_position_independent_swizzle_layout(layout));
|
||||
}
|
||||
|
||||
template <int B>
|
||||
CUTE_HOST_DEVICE void print(smem_ptr_flag_bits<B> ptr)
|
||||
{
|
||||
|
||||
@ -56,3 +56,9 @@
|
||||
#include <cute/algorithm/cooperative_copy.hpp>
|
||||
#include <cute/algorithm/cooperative_gemm.hpp>
|
||||
|
||||
//
|
||||
// Utilities
|
||||
//
|
||||
|
||||
#include <cute/util/print_tensor.hpp>
|
||||
#include <cute/util/print_latex.hpp>
|
||||
|
||||
@ -753,24 +753,30 @@ domain_offset(Coord const& coord, Tensor&& tensor)
|
||||
// -- doesn't check dynamic integer divisibility
|
||||
// -- doesn't check alignment
|
||||
|
||||
template <class NewType, class Tensor>
|
||||
template <class NewType_, class Tensor>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(Tensor&& tensor)
|
||||
{
|
||||
using OldType = typename remove_cvref_t<Tensor>::value_type;
|
||||
using OldType = typename remove_cvref_t<Tensor>::element_type;
|
||||
using NewType = copy_cv_t<OldType, NewType_>;
|
||||
|
||||
auto old_layout = tensor.layout();
|
||||
auto new_layout = recast_layout<OldType,NewType>(old_layout);
|
||||
|
||||
// If this is an upcast of a normal Layout with static negative strides, then offset as well
|
||||
if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
|
||||
auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
|
||||
auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
|
||||
auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
|
||||
|
||||
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() + offset), new_layout);
|
||||
if constexpr (is_same<NewType, OldType>::value) {
|
||||
return tensor;
|
||||
} else {
|
||||
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() ), new_layout);
|
||||
// If this is an upcast of a normal Layout with static negative strides, then offset as well
|
||||
if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
|
||||
auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
|
||||
auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
|
||||
auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
|
||||
|
||||
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() + offset), new_layout);
|
||||
} else {
|
||||
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() ), new_layout);
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@ -1114,95 +1120,5 @@ CUTE_HOST_DEVICE void print(Tensor<Engine,Layout> const& tensor)
|
||||
print(tensor.data()); print(" o "); print(tensor.layout());
|
||||
}
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor, bool print_type = true)
|
||||
{
|
||||
if (print_type) {
|
||||
print(tensor); print(":\n");
|
||||
}
|
||||
|
||||
if constexpr (Layout::rank == 1)
|
||||
{
|
||||
for (int m = 0; m < size(tensor); ++m) {
|
||||
pretty_print(tensor(m));
|
||||
printf("\n");
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 2)
|
||||
{
|
||||
for (int m = 0; m < size<0>(tensor); ++m) {
|
||||
for (int n = 0; n < size<1>(tensor); ++n) {
|
||||
pretty_print(tensor(m,n));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 3)
|
||||
{
|
||||
print_tensor(tensor(_,_,0), false);
|
||||
for (int k = 1; k < size<2>(tensor); ++k) {
|
||||
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n");
|
||||
print_tensor(tensor(_,_,k), false);
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 4)
|
||||
{
|
||||
print_tensor(tensor(_,_,_,0), false);
|
||||
for (int p = 1; p < size<3>(tensor); ++p) {
|
||||
for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n");
|
||||
print_tensor(tensor(_,_,_,p), false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
int digits = 9;
|
||||
|
||||
if constexpr (Layout::rank == 1)
|
||||
{
|
||||
for (int m = 0; m < size(tensor); ++m) {
|
||||
os << std::setw(digits) << tensor(m) << std::endl;
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 2)
|
||||
{
|
||||
for (int m = 0; m < size<0>(tensor); ++m) {
|
||||
for (int n = 0; n < size<1>(tensor); ++n) {
|
||||
os << std::setw(digits) << tensor(m,n);
|
||||
}
|
||||
os << std::endl;
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 3)
|
||||
{
|
||||
print_tensor_os(os, tensor(_,_,0));
|
||||
for (int k = 1; k < size<2>(tensor); ++k) {
|
||||
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl;
|
||||
print_tensor_os(os, tensor(_,_,k));
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 4)
|
||||
{
|
||||
print_tensor_os(os, tensor(_,_,_,0));
|
||||
for (int p = 1; p < size<3>(tensor); ++p) {
|
||||
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl;
|
||||
print_tensor_os(os, tensor(_,_,_,p));
|
||||
}
|
||||
}
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
os << tensor.layout() << std::endl;
|
||||
return print_tensor_os(os, tensor);
|
||||
}
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
|
||||
438
include/cute/util/print_latex.hpp
Normal file
438
include/cute/util/print_latex.hpp
Normal file
@ -0,0 +1,438 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp> // CUTE_HOST_DEVICE
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
///////////////////////////////////////
|
||||
// Common LaTeX TikZ Color utilities //
|
||||
///////////////////////////////////////
|
||||
|
||||
struct TikzColor_White {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int idx) const {
|
||||
return "white";
|
||||
}
|
||||
};
|
||||
|
||||
struct TikzColor_BWx8 {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int idx) const {
|
||||
static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60",
|
||||
"black!10", "black!50", "black!30", "black!70"};
|
||||
return color_map[idx % 8];
|
||||
}
|
||||
};
|
||||
|
||||
struct TikzColor_TV {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int tid, int vid) const {
|
||||
static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
|
||||
"{rgb,255:red,175;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,175;blue,175}",
|
||||
"{rgb,255:red,210;green,210;blue,255}",
|
||||
"{rgb,255:red,210;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,210;blue,210}"};
|
||||
return color_map[tid % 8];
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////
|
||||
// Layout 2D to LaTeX TikZ //
|
||||
/////////////////////////////
|
||||
|
||||
// Render a rank-<=2 layout as a standalone LaTeX TikZ document on stdout:
// one 1cm cell per (m,n) coordinate showing layout(m,n), colored by the
// index, with m/n index labels along the left and top edges.
template <class LayoutA, class TikzColorFn = TikzColor_BWx8>
CUTE_HOST_DEVICE
void
print_latex(LayoutA const& layout_a,   // (m,n) -> idx
            TikzColorFn color = {})    // lambda(idx) -> tikz color string
{
  CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{});
  // Pad rank-1 layouts to rank-2 with a trivial (size-1, stride-0) mode
  auto layout = append<2>(layout_a, Layout<_1,_0>{});

  // Commented print(layout) -- "%%" keeps it a LaTeX comment in the output
  printf("%% Layout: "); print(layout); printf("\n");
  // Header
  printf("\\documentclass[convert]{standalone}\n"
         "\\usepackage{tikz}\n\n"
         "\\begin{document}\n"
         "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");

  auto [M, N] = product_each(shape(layout));

  // Layout: one node per coordinate
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int idx = layout(m,n);
      printf("\\node[fill=%s] at (%d,%d) {%d};\n",
             color(idx), m, n, idx);
    }
  }
  // Grid: the -0.5 shift centers each cell on its integer coordinate
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
         int(M), int(N));
  // Labels: row indices in column -1, column indices in row -1
  for (int m = 0, n = -1; m < M; ++m) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m);
  }
  for (int m = -1, n = 0; n < N; ++n) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n);
  }

  // Footer
  printf("\\end{tikzpicture}\n"
         "\\end{document}\n");
}
|
||||
|
||||
// Overload for swizzled shared-memory layouts: cast the smem_ptr_flag
// composed layout to an equivalent position-independent layout, then
// render it with the generic print_latex above.
template <class SwizzleFn, int B, class Layout, class TikzColorFn = TikzColor_BWx8>
CUTE_HOST_DEVICE
void
print_latex(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout,
            TikzColorFn color = {})  // lambda(idx) -> tikz color string
{
  print_latex(as_position_independent_swizzle_layout(layout), color);
}
|
||||
|
||||
///////////////////////////////
|
||||
// LayoutTV 2D to LaTeX TikZ //
|
||||
///////////////////////////////
|
||||
|
||||
// Render a (thr,val) -> (m,n) mapping over an (M,N) tile as a standalone
// LaTeX TikZ document: each cell is labeled "T<tid> V<vid>" with the FIRST
// thread/value pair that touches it, plus m/n index labels on the edges.
template <class LayoutTV, class Tile_MN,
          class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_tv(LayoutTV const& layout_tv,  // (t,v) -> m,n coord
               Tile_MN const& tile_mn,     // (M,N)
               TikzColorFn color = {})     // (t,v) -> color
{
  CUTE_STATIC_ASSERT_V(rank(layout_tv) == Int<2>{});

  // Commented prints -- "%%" keeps them LaTeX comments in the output
  printf("%% Layout TV: "); print(layout_tv); printf("\n");
  // Header
  printf("\\documentclass[convert]{standalone}\n"
         "\\usepackage{tikz}\n\n"
         "\\begin{document}\n"
         "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");

  auto [M, N] = product_each(shape(tile_mn));
  // Tracks drawn cells so only the first (tid,vid) landing on a cell labels it
  Tensor filled = make_tensor<bool>(make_shape(M, N));
  clear(filled);

  // Layout
  for (int tid = 0; tid < size<0>(layout_tv); ++tid) {
    for (int vid = 0; vid < size<1>(layout_tv); ++vid) {
      auto [m, n] = layout_tv(tid, vid);
      if (not filled(m, n)) {
        filled(m, n) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(m), int(n),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", int(M), int(N));
  // Labels: row indices in column -1, column indices in row -1
  for (int m = 0, n = -1; m < M; ++m) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m);
  }
  for (int n = 0, m = -1; n < N; ++n) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n);
  }
  // Footer
  printf("\\end{tikzpicture}\n"
         "\\end{document}\n");
}
|
||||
|
||||
////////////////////////////
|
||||
// MMA Atom to LaTeX TikZ //
|
||||
////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Render the C/A/B (thr,val)->coord mappings of one MMA tile as a single
// standalone LaTeX TikZ document. C is drawn with its top-left at the origin,
// A to the left of C (second axis shifted by -K-1), and B above C (first axis
// shifted by -K-1), so the three operand grids line up like a GEMM diagram.
// Each cell is labeled with the FIRST (tid,vid) pair that touches it.
template <class LayoutC, class LayoutA, class LayoutB, class Tile_MNK,
          class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_mma(LayoutC const& C,          // (tid,vid) -> (m,n) coord
                LayoutA const& A,          // (tid,vid) -> (m,k) coord
                LayoutB const& B,          // (tid,vid) -> (n,k) coord
                Tile_MNK const& tile_mnk,  // (M,N,K)
                TikzColorFn color = {})    // lambda(tid,vid) -> tikz color string
{
  CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});

  // Commented prints -- "%%" keeps them LaTeX comments in the output
  printf("%% LayoutC: "); print(C); printf("\n");
  printf("%% LayoutA: "); print(A); printf("\n");
  printf("%% LayoutB: "); print(B); printf("\n");
  // Header
  printf("\\documentclass[convert]{standalone}\n"
         "\\usepackage{tikz}\n\n"
         "\\begin{document}\n"
         "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");

  auto [M, N, K] = product_each(shape(tile_mnk));
  // One shared first-writer-wins mask, reused (via clear) for each operand;
  // C uses the (m,n,0) plane, A the (m,0,k) plane, B the (0,n,k) plane.
  Tensor filled = make_tensor<bool>(make_shape(M, N, K));
  clear(filled);

  // C starting at 0,0
  for (int tid = 0; tid < size<0>(C); ++tid) {
    for (int vid = 0; vid < size<1>(C); ++vid) {
      auto [m, n] = C(tid, vid);
      if (not filled(m, n, 0)) {
        filled(m, n, 0) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(m), int(n),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
         0, 0, int(M), int(N));

  clear(filled);

  // A starting at 0,-K-1 (drawn to the left of C, one blank column between)
  for (int tid = 0; tid < size<0>(A); ++tid) {
    for (int vid = 0; vid < size<1>(A); ++vid) {
      auto [m, k] = A(tid, vid);
      if (not filled(m, 0, k)) {
        filled(m, 0, k) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(m), int(k-K-1),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
         0, -int(K)-1, int(M), -1);
  // A labels: m indices left of the A grid, k indices above it
  for (int m = 0, k = -1; m < M; ++m) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), m);
  }
  for (int m = -1, k = 0; k < K; ++k) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), k);
  }

  clear(filled);

  // B starting at -K-1,0 (drawn above C, one blank row between)
  for (int tid = 0; tid < size<0>(B); ++tid) {
    for (int vid = 0; vid < size<1>(B); ++vid) {
      auto [n, k] = B(tid, vid);
      if (not filled(0, n, k)) {
        filled(0, n, k) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(k)-int(K)-1, int(n),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
         -int(K)-1, 0, -1, int(N));
  // B labels: n indices above the B grid, k indices left of it
  for (int n = 0, k = -1; n < N; ++n) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, n);
  }
  for (int n = -1, k = 0; k < K; ++k) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, k);
  }

  // Footer
  printf("\\end{tikzpicture}\n"
         "\\end{document}\n");
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// MMA Atom to LaTeX TikZ
|
||||
template <class... Args, class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(MMA_Atom<Args...> const& mma_atom,
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
print_latex(make_tiled_mma(mma_atom));
|
||||
}
|
||||
|
||||
// TiledMMA to LaTeX TikZ
|
||||
template <class... Args, class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(TiledMMA<Args...> const& mma,
|
||||
TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string
|
||||
{
|
||||
auto tile_mnk = tile_shape(mma);
|
||||
|
||||
Tensor refC = make_identity_tensor(select<0,1>(tile_mnk));
|
||||
Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV());
|
||||
|
||||
Tensor refA = make_identity_tensor(select<0,2>(tile_mnk));
|
||||
Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV());
|
||||
|
||||
Tensor refB = make_identity_tensor(select<1,2>(tile_mnk));
|
||||
Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV());
|
||||
|
||||
detail::print_latex_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color);
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// CopyAtom to LaTeX TikZ //
|
||||
////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Generic TV Layout to LaTeX TikZ
|
||||
// Render a copy's source and destination (thr,val) -> (m,n) mappings side by
// side as one standalone LaTeX TikZ document: S at the origin, D shifted
// right by N+3 cells. Each cell is labeled with the FIRST thread/value pair
// that touches it.
template <class LayoutS_TV, class LayoutD_TV, class Tile_MN,
          class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_copy(LayoutS_TV const& S,     // (t,v) -> m,n coord
                 LayoutD_TV const& D,     // (t,v) -> m,n coord
                 Tile_MN const& tile_mn,  // (M,N)
                 TikzColorFn color = {})  // (t,v) -> color
{
  CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});

  // Commented prints -- "%%" keeps them LaTeX comments in the output
  printf("%% Layout S TV: "); print(S); printf("\n");
  printf("%% Layout D TV: "); print(D); printf("\n");

  // Header
  printf("\\documentclass[convert]{standalone}\n"
         "\\usepackage{tikz}\n\n"
         "\\begin{document}\n"
         "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");

  auto [M, N] = product_each(shape(tile_mn));
  // First-writer-wins mask, reused for S then D
  Tensor filled = make_tensor<bool>(make_shape(M, N));
  clear(filled);

  // S starting at 0,0
  for (int tid = 0; tid < size<0>(S); ++tid) {
    for (int vid = 0; vid < size<1>(S); ++vid) {
      auto [m, n] = S(tid, vid);
      if (not filled(m, n)) {
        filled(m, n) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(m), int(n),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
         0, 0, int(M), int(N));
  // S Labels: row indices in column -1, column indices in row -1
  for (int m = 0, n = -1; m < M; ++m) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m);
  }
  for (int m = -1, n = 0; n < N; ++n) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n);
  }

  clear(filled);

  // D starting at 0,N+3 (three blank columns between S and D)
  for (int tid = 0; tid < size<0>(D); ++tid) {
    for (int vid = 0; vid < size<1>(D); ++vid) {
      auto [m, n] = D(tid, vid);
      if (not filled(m, n)) {
        filled(m, n) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(m), int(n) + int(N) + 3,
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
         0, int(N) + 3, int(M), int(N) + int(N) + 3);
  // D Labels: row indices sit on the RIGHT side of the D grid (n starts at N),
  // column indices in row -1 above it
  for (int m = 0, n = N; m < M; ++m) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), m);
  }
  for (int m = -1, n = 0; n < N; ++n) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), n);
  }

  // Footer
  printf("\\end{tikzpicture}\n"
         "\\end{document}\n");
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// TiledCopy to LaTeX TikZ
|
||||
template <class... Args, class TikzColorFn = TikzColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(TiledCopy<Args...> const& copy,
|
||||
TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string
|
||||
{
|
||||
auto tiler_mn = typename TiledCopy<Args...>::Tiler_MN{};
|
||||
auto tile_mn = product_each(shape(logical_divide(make_layout(Shape<_1,_1>{}), tiler_mn))); // tile_shape
|
||||
|
||||
Tensor refS = make_identity_tensor(tile_mn);
|
||||
Tensor layoutS_TV = copy.tidfrg_S(refS)(_,_,Int<0>{});
|
||||
|
||||
Tensor refD = make_identity_tensor(tile_mn);
|
||||
Tensor layoutD_TV = copy.tidfrg_D(refD)(_,_,Int<0>{});
|
||||
|
||||
detail::print_latex_copy(layoutS_TV, layoutD_TV, tile_mn, color);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
// ==================== include/cute/util/print_svg.hpp (new file, 257 lines) ====================
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp> // CUTE_HOST_DEVICE
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
////////////////////////////////
|
||||
// Common SVG Color utilities //
|
||||
////////////////////////////////
|
||||
|
||||
struct TSVGColor_White {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int idx) const {
|
||||
return "255,255,255";
|
||||
}
|
||||
};
|
||||
|
||||
struct TSVGColor_BWx8 {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int idx) const {
|
||||
static char const* color_map[8] = {"255,255,255", "230,230,230", "205,205,205", "180,180,180",
|
||||
"155,155,155", "130,130,130", "105,105,105", "080,080,080"};
|
||||
return color_map[idx % 8];
|
||||
}
|
||||
};
|
||||
|
||||
struct SVGColor_TV {
|
||||
CUTE_HOST_DEVICE char const*
|
||||
operator()(int tid, int vid) const {
|
||||
static char const* color_map[8] = {"175,175,255", "175,255,175", "255,255,175", "255,175,175",
|
||||
"210,210,255", "210,255,210", "255,255,210", "255,210,210"};
|
||||
return color_map[tid % 8];
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////
|
||||
// MMA Atom to SVG //
|
||||
/////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Render the C/A/B (thr,val)->coord mappings of one MMA tile as a single SVG
// document on stdout. Cells are cell_size px squares: C occupies the block
// offset by (K+2, K+2) cells, A sits to its left (x offset 1 cell), and B
// above it (y offset 1 cell), forming the usual GEMM diagram. Each cell shows
// the FIRST (tid,vid) pair that touches it.
template <class LayoutC, class LayoutA, class LayoutB, class Tile_MNK,
          class SVGColorFn = SVGColor_TV>
CUTE_HOST_DEVICE
void
print_svg_mma(LayoutC const& C,          // (tid,vid) -> (m,n) coord
              LayoutA const& A,          // (tid,vid) -> (m,k) coord
              LayoutB const& B,          // (tid,vid) -> (n,k) coord
              Tile_MNK const& tile_mnk,  // (M,N,K)
              SVGColorFn color = {})     // lambda(tid,vid) -> SVG color string
{
  CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});

  auto [M, N, K] = product_each(shape(tile_mnk));

  // Side length of one cell, in SVG user units (pixels)
  int cell_size = 20;

  // Canvas large enough for A+C horizontally and B+C vertically, plus margins
  int page_width  = (K + N + 2) * cell_size;
  int page_height = (K + M + 2) * cell_size;

  // Commented prints -- emitted as SVG/XML comments
  printf("<!-- Tile: "); print(tile_mnk); printf(" -->\n");
  printf("<!-- A: "); print(A); printf(" -->\n");
  printf("<!-- B: "); print(B); printf(" -->\n");
  printf("<!-- C: "); print(C); printf(" -->\n");

  // SVG Header
  printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
         "preserveAspectRatio=\"xMidYMid meet\" "
         "xmlns=\"http://www.w3.org/2000/svg\">\n",
         page_width, page_height);

  // First-writer-wins mask, reused (via clear) for each operand; C uses the
  // (m,n,0) plane, A the (m,0,k) plane, B the (0,n,k) plane.
  Tensor filled = make_tensor<bool>(make_shape(M, N, K));
  clear(filled);

  // --- Draw C --- (offset by K+2 cells in both axes)
  for (int tid = 0; tid < size<0>(C); ++tid) {
    for (int vid = 0; vid < size<1>(C); ++vid) {
      auto [m, n] = C(tid, vid);
      if (!filled(m, n, 0)) {
        filled(m, n, 0) = true;

        int x = (n + K + 2) * cell_size;
        int y = (m + K + 2) * cell_size;

        printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
               "fill=\"rgb(%s)\" stroke=\"black\"/>\n",
               x, y, cell_size, cell_size, color(tid,vid));
        // Thread id on the upper line, value id on the lower line of the cell
        printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
               "alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
               x + cell_size/2, y + 1*cell_size/4, tid);
        printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
               "alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
               x + cell_size/2, y + 3*cell_size/4, vid);
      }
    }
  }

  clear(filled);

  // --- Draw A --- (left of C: k along x from cell 1, m along y aligned with C)
  for (int tid = 0; tid < size<0>(A); ++tid) {
    for (int vid = 0; vid < size<1>(A); ++vid) {
      auto [m, k] = A(tid, vid);
      if (!filled(m, 0, k)) {
        filled(m, 0, k) = true;

        int x = (k + 1) * cell_size;
        int y = (m + K + 2) * cell_size;

        printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
               "fill=\"rgb(%s)\" stroke=\"black\" />\n",
               x, y, cell_size, cell_size, color(tid,vid));
        printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
               "alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
               x + cell_size/2, y + 1*cell_size/4, tid);
        printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
               "alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
               x + cell_size/2, y + 3*cell_size/4, vid);
      }
    }
  }

  // A labels: m indices in the column left of A, k indices in the row above A
  for (int m = 0, k = -1; m < M; ++m) {
    int x = (k + 1) * cell_size;
    int y = (m + K + 2) * cell_size;
    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
           x + cell_size/2, y + cell_size/2, m);
  }
  for (int m = -1, k = 0; k < K; ++k) {
    int x = (k + 1) * cell_size;
    int y = (m + K + 2) * cell_size;
    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
           x + cell_size/2, y + cell_size/2, k);
  }

  clear(filled);

  // --- Draw B --- (above C: n along x aligned with C, k along y from cell 1)
  for (int tid = 0; tid < size<0>(B); ++tid) {
    for (int vid = 0; vid < size<1>(B); ++vid) {
      auto [n, k] = B(tid, vid);
      if (!filled(0, n, k)) {
        filled(0, n, k) = true;

        int x = (n + K + 2) * cell_size;
        int y = (k + 1) * cell_size;

        printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
               "fill=\"rgb(%s)\" stroke=\"black\" />\n",
               x, y, cell_size, cell_size, color(tid,vid));
        printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
               "alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
               x + cell_size/2, y + 1*cell_size/4, tid);
        printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
               "alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
               x + cell_size/2, y + 3*cell_size/4, vid);
      }
    }
  }

  // B labels: n indices in the row above B, k indices in the column left of B
  for (int n = 0, k = -1; n < N; ++n) {
    int x = (n + K + 2) * cell_size;
    int y = (k + 1) * cell_size;
    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
           x + cell_size/2, y + cell_size/2, n);
  }
  for (int n = -1, k = 0; k < K; ++k) {
    int x = (n + K + 2) * cell_size;
    int y = (k + 1) * cell_size;
    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
           x + cell_size/2, y + cell_size/2, k);
  }

  // SVG footer
  printf("</svg>\n");
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// MMA Atom to SVG
|
||||
template <class... Args, class SVGColorFn = SVGColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg(MMA_Atom<Args...> const& mma_atom,
|
||||
SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string
|
||||
{
|
||||
print_svg(make_tiled_mma(mma_atom));
|
||||
}
|
||||
|
||||
// TiledMMA to SVG
|
||||
template <class... Args, class SVGColorFn = SVGColor_TV>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg(TiledMMA<Args...> const& mma,
|
||||
SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string
|
||||
{
|
||||
auto tile_mnk = tile_shape(mma);
|
||||
|
||||
Tensor refC = make_identity_tensor(select<0,1>(tile_mnk));
|
||||
Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV());
|
||||
|
||||
Tensor refA = make_identity_tensor(select<0,2>(tile_mnk));
|
||||
Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV());
|
||||
|
||||
Tensor refB = make_identity_tensor(select<1,2>(tile_mnk));
|
||||
Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV());
|
||||
|
||||
detail::print_svg_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
// ==================== include/cute/util/print_tensor.hpp (new file, 188 lines) ====================
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp> // CUTE_HOST_DEVICE
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
////////////////////////////////
|
||||
// Layout 2D to Console table //
|
||||
////////////////////////////////
|
||||
|
||||
// Print a rank-2 layout as an ASCII table on stdout: the layout expression on
// the first line, then an m-by-n grid where each cell shows layout(m,n).
template <class Layout>
CUTE_HOST_DEVICE
void
print_layout(Layout const& layout)  // (m,n) -> idx
{
  CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});

  // Cell width: widest possible index (cosize-1) plus padding
  int idx_width = num_digits(cosize(layout)) + 2;
  // Horizontal-rule source string; "%.*s" below truncates it to cell width
  const char* delim = "+-----------------------";

  print(layout); print("\n");

  // Column indices
  print("  ");
  for (int n = 0; n < size<1>(layout); ++n) { printf("  %*d ", idx_width-2, n); }
  printf("\n");

  // Print out A m-by-n
  for (int m = 0; m < size<0>(layout); ++m) {
    // Header rule above this row
    print("  ");
    for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
    printf("+\n");
    // Values
    printf("%2d  ", m);  // Row indices
    for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); }
    printf("|\n");
  }
  // Footer rule closing the table
  print("  ");
  for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
  printf("+\n");
}
|
||||
|
||||
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts
|
||||
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts
// Overload for swizzled shared-memory layouts: convert to an equivalent
// position-independent layout, then print with the generic overload above.
template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
void
print_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  print_layout(as_position_independent_swizzle_layout(layout));
}
|
||||
|
||||
////////////////////////////////
|
||||
// Tensor 1D,2D,3D,4D Console //
|
||||
////////////////////////////////
|
||||
|
||||
// Pretty-print a tensor's values to stdout. Rank-1 prints one element per
// line, rank-2 prints an m-by-n grid, rank-3/4 print rank-2 slices separated
// by '-' / '=' rules. Ranks above 4 print only the type header (if enabled).
// print_type=false suppresses the header on recursive slice calls.
template <class Engine, class Layout>
CUTE_HOST_DEVICE
void
print_tensor(Tensor<Engine,Layout> const& tensor, bool print_type = true)
{
  if (print_type) {
    print(tensor); print(":\n");
  }

  if constexpr (Layout::rank == 1)
  {
    for (int m = 0; m < size(tensor); ++m) {
      pretty_print(tensor(m));
      printf("\n");
    }
  } else
  if constexpr (Layout::rank == 2)
  {
    for (int m = 0; m < size<0>(tensor); ++m) {
      for (int n = 0; n < size<1>(tensor); ++n) {
        pretty_print(tensor(m,n));
      }
      printf("\n");
    }
  } else
  if constexpr (Layout::rank == 3)
  {
    // Print each (m,n) slice, separated by a '-' rule
    // (rule width assumes ~5 chars per pretty-printed element -- TODO confirm)
    print_tensor(tensor(_,_,0), false);
    for (int k = 1; k < size<2>(tensor); ++k) {
      for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n");
      print_tensor(tensor(_,_,k), false);
    }
  } else
  if constexpr (Layout::rank == 4)
  {
    // Print each rank-3 slice, separated by a heavier '=' rule
    print_tensor(tensor(_,_,_,0), false);
    for (int p = 1; p < size<3>(tensor); ++p) {
      for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n");
      print_tensor(tensor(_,_,_,p), false);
    }
  }
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
// Stream-based counterpart of print_tensor: write a tensor's values (ranks
// 1-4) to an std::ostream, each element in a fixed-width field of 9 chars.
// Higher ranks produce no element output. Returns the stream for chaining.
template <class Engine, class Layout>
CUTE_HOST
std::ostream&
print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
  // Fixed field width per element; also sizes the slice-separator rules
  int digits = 9;

  if constexpr (Layout::rank == 1)
  {
    for (int m = 0; m < size(tensor); ++m) {
      os << std::setw(digits) << tensor(m) << std::endl;
    }
  } else
  if constexpr (Layout::rank == 2)
  {
    for (int m = 0; m < size<0>(tensor); ++m) {
      for (int n = 0; n < size<1>(tensor); ++n) {
        os << std::setw(digits) << tensor(m,n);
      }
      os << std::endl;
    }
  } else
  if constexpr (Layout::rank == 3)
  {
    // Each (m,n) slice separated by a '-' rule matching the row width
    print_tensor_os(os, tensor(_,_,0));
    for (int k = 1; k < size<2>(tensor); ++k) {
      for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl;
      print_tensor_os(os, tensor(_,_,k));
    }
  } else
  if constexpr (Layout::rank == 4)
  {
    // Each rank-3 slice separated by a heavier '=' rule
    print_tensor_os(os, tensor(_,_,_,0));
    for (int p = 1; p < size<3>(tensor); ++p) {
      for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl;
      print_tensor_os(os, tensor(_,_,_,p));
    }
  }

  return os;
}
|
||||
|
||||
// Stream insertion for tensors: the layout on the first line, then the
// formatted values via print_tensor_os.
template <class Engine, class Layout>
CUTE_HOST
std::ostream&
operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
  os << tensor.layout() << std::endl;
  return print_tensor_os(os, tensor);
}
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
} // end namespace cute
|
||||
// ==================== include/cute/util/type_traits.hpp (hunk at line 92) ====================
using CUTE_STL_NAMESPACE::remove_cv_t;
|
||||
using CUTE_STL_NAMESPACE::remove_reference_t;
|
||||
|
||||
// copy_cv<Src,Dst>: transfer the top-level cv-qualifiers of Src onto Dst.

// Primary template: Src carries no cv-qualifiers, Dst passes through unchanged.
template <class S, class D>
struct copy_cv {
  using type = D;
};

// Src is const- and volatile-qualified: propagate both qualifiers.
template <class S, class D>
struct copy_cv<S const volatile, D> {
  using type = D const volatile;
};

// Src is const-qualified: propagate const.
template <class S, class D>
struct copy_cv<S const, D> {
  using type = D const;
};

// Src is volatile-qualified: propagate volatile.
template <class S, class D>
struct copy_cv<S volatile, D> {
  using type = D volatile;
};

// Convenience alias for copy_cv<Src,Dst>::type.
template <class S, class D>
using copy_cv_t = typename copy_cv<S,D>::type;
|
||||
|
||||
using CUTE_STL_NAMESPACE::extent;
|
||||
using CUTE_STL_NAMESPACE::remove_extent;
|
||||
|
||||
|
||||
118
include/cutlass/arch/mma_sm100.h
Normal file
118
include/cutlass/arch/mma_sm100.h
Normal file
@ -0,0 +1,118 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Matrix multiply
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/config.h"
|
||||
#include "cute/arch/simd_sm100.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass{
|
||||
namespace arch {
|
||||
|
||||
|
||||
/// Matrix multiply-add operation
|
||||
template <
|
||||
/// Data type of A elements
|
||||
typename ElementA,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA,
|
||||
/// Data type of B elements
|
||||
typename ElementB,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC
|
||||
>
|
||||
struct Mma<gemm::GemmShape<2, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<2, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = ElementC_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
Array<ElementC, 2> &d,
|
||||
Array<ElementA, 2> const &a,
|
||||
Array<ElementB, 1> const &b,
|
||||
Array<ElementC, 2> const &c
|
||||
) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
d[i] = a[i] * b[0] + c[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Matrix multiply-add operation
|
||||
template <
|
||||
/// Layout of A matrix
|
||||
typename LayoutA,
|
||||
/// Layout of B matrix
|
||||
typename LayoutB,
|
||||
/// Layout of C matrix
|
||||
typename LayoutC
|
||||
>
|
||||
struct Mma<gemm::GemmShape<2, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<2, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = float;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
Array<float, 2> &d,
|
||||
Array<float, 2> const &a,
|
||||
Array<float, 1> const &b,
|
||||
Array<float, 2> const &c
|
||||
) {
|
||||
float2 result;
|
||||
cute::fma(result, make_float2(a[0], a[1]), make_float2(b[0], b[0]), make_float2(c[0], c[1]));
|
||||
d[0] = result.x;
|
||||
d[1] = result.y;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
@ -88,7 +88,7 @@ struct LayoutAwareConvertImpl<
|
||||
static void convert(
|
||||
cute::Tensor<EngineIn,
|
||||
cute::Layout<cute::Shape<_2,_4>, cute::Stride<_4,_1>>
|
||||
> const& src,
|
||||
> const& src,
|
||||
cute::Tensor<EngineOut,
|
||||
cute::Layout<_8>
|
||||
>& dst) {
|
||||
@ -136,7 +136,7 @@ struct LayoutAwareConvertImpl<
|
||||
static void convert(
|
||||
cute::Tensor<EngineIn,
|
||||
cute::Layout<cute::Shape<_2,_4>, cute::Stride<_4,_1>>
|
||||
> const& src,
|
||||
> const& src,
|
||||
cute::Tensor<EngineOut,
|
||||
cute::Layout<_8>
|
||||
>& dst) {
|
||||
@ -184,7 +184,7 @@ struct LayoutAwareConvertImpl<
|
||||
static void convert(
|
||||
cute::Tensor<EngineIn,
|
||||
cute::Layout<cute::Shape<_2,_4>, cute::Stride<_4,_1>>
|
||||
> const& src,
|
||||
> const& src,
|
||||
cute::Tensor<EngineOut,
|
||||
cute::Layout<_8>
|
||||
>& dst) {
|
||||
@ -250,7 +250,7 @@ struct LayoutAwareConvertImpl<
|
||||
static void convert(
|
||||
cute::Tensor<EngineIn,
|
||||
cute::Layout<cute::Shape<_2,_4>, cute::Stride<_4,_1>>
|
||||
> const& src,
|
||||
> const& src,
|
||||
cute::Tensor<EngineOut,
|
||||
cute::Layout<_8>
|
||||
>& dst) {
|
||||
@ -477,7 +477,7 @@ void LayoutAwareConvert(
|
||||
Tensor dst_vm = coalesce(dst);
|
||||
Layout src_layout = src_vm.layout();
|
||||
Layout dst_layout = dst_vm.layout();
|
||||
LayoutAwareConvertImpl<SrcType,
|
||||
LayoutAwareConvertImpl<SrcType,
|
||||
DstType,
|
||||
decltype(src_layout),
|
||||
decltype(dst_layout)>::convert(src_vm, dst_vm);
|
||||
@ -487,18 +487,25 @@ void LayoutAwareConvert(
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace detail {
|
||||
enum class ConversionMode {
|
||||
DirectConvert, // A * B
|
||||
ConvertAndScale, // (scale * A) * B
|
||||
ConvertAndScaleWithZero // (scale * A + zeros) * B
|
||||
};
|
||||
} // namespace detail
|
||||
} //namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective::detail {
|
||||
|
||||
template <class PointerType>
|
||||
static constexpr
|
||||
CUTLASS_HOST_DEVICE
|
||||
auto get_logical_ptr(PointerType const* ptr) {
|
||||
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
|
||||
return subbyte_iterator<PointerType const>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
return cute::recast_ptr<PointerType const>(ptr);
|
||||
}
|
||||
template<int Stages, class LayoutAtom, class TileShape, class Stride>
|
||||
static constexpr
|
||||
@ -530,8 +537,8 @@ auto get_gmem_layout(Shape const& shape, Stride const& stride) {
|
||||
template<class Collective>
|
||||
struct MixedInputUtils {
|
||||
private:
|
||||
using ConversionMode = cutlass::detail::ConversionMode;
|
||||
using KernelSchedule = typename Collective::KernelSchedule;
|
||||
using ConversionMode = typename Collective::ConversionMode;
|
||||
using SmemLayoutA = typename Collective::SmemLayoutA;
|
||||
using SmemLayoutB = typename Collective::SmemLayoutB;
|
||||
using SmemLayoutScale = typename Collective::SmemLayoutScale;
|
||||
@ -551,10 +558,10 @@ public:
|
||||
elements_per_smem_scale() {
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
return cute::cosize_v<SmemLayoutScale>;
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
|
||||
}
|
||||
@ -565,10 +572,10 @@ public:
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
|
||||
KernelConversionMode == ConversionMode::ConvertAndScale ) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
return cute::cosize_v<SmemLayoutScale>;
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
|
||||
}
|
||||
@ -634,7 +641,7 @@ public:
|
||||
// We are starting a new k-tile so copy the scale
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
// nothing to do
|
||||
}
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
|
||||
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
|
||||
@ -649,13 +656,23 @@ public:
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to select packing for conversion
|
||||
template <class SrcType,
|
||||
class DstType,
|
||||
int Cosize>
|
||||
struct select_packing { // Naive packing policy
|
||||
static constexpr auto value() {
|
||||
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
|
||||
}
|
||||
};
|
||||
|
||||
// The core converter uses a lookup table to converts i4 -> 8 bit value.
|
||||
template <class EngineIn,
|
||||
class LayoutIn,
|
||||
@ -669,7 +686,7 @@ public:
|
||||
Tensor<EngineOut, LayoutOut> && dst,
|
||||
Tensor<EngineScale, LayoutScale> const& scales_neg,
|
||||
Tensor<EngineScale, LayoutScale> const& scales_pos) {
|
||||
|
||||
|
||||
lookup_table_convert(src, dst, scales_neg, scales_pos);
|
||||
}
|
||||
template <class EngineIn,
|
||||
@ -687,7 +704,7 @@ public:
|
||||
|
||||
constexpr int N = cute::cosize(LayoutIn{});
|
||||
static_assert(N == 4 || N == 8);
|
||||
static_assert(cosize(LayoutScale{}) <= N / 4,
|
||||
static_assert(cosize(LayoutScale{}) <= N / 4,
|
||||
"at least 4 consecutive weights must share the same scale.");
|
||||
using SrcArray = cutlass::Array<cutlass::int4b_t, 8>;
|
||||
using DstArray = cutlass::Array<RealSwappedElementB, 8>;
|
||||
@ -699,7 +716,7 @@ public:
|
||||
|
||||
// Determines if to get from the signed or unsigned candidates
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1
|
||||
uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %1, %2, %3, %4;\n" \
|
||||
@ -743,13 +760,13 @@ public:
|
||||
static_check_scale(flatten(Layout{}));
|
||||
}
|
||||
template <class EngineIn,
|
||||
class EngineOut,
|
||||
class EngineOut,
|
||||
class LayoutIn,
|
||||
class LayoutOut,
|
||||
class... Ts>
|
||||
CUTLASS_DEVICE
|
||||
static void dequantize_A_kblock(
|
||||
Tensor<EngineIn, LayoutIn> const& tCrA_load,
|
||||
Tensor<EngineIn, LayoutIn> const& tCrA_load,
|
||||
Tensor<EngineOut, LayoutOut>& tCrA_mma,
|
||||
cute::tuple<Ts...>& partitioned_extra_info,
|
||||
int const k_block) {
|
||||
@ -764,7 +781,7 @@ public:
|
||||
|
||||
Tensor src = tCrA_load(_, _, k_block);
|
||||
Tensor dst = tCrA_mma(_, _, k_block);
|
||||
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size(src(_, 0)) == cosize(src(_, 0).layout()),
|
||||
"The first mode of tensor src must be contiguous in memory");
|
||||
// try to make the size of the first mode equal to 32bit
|
||||
@ -778,7 +795,7 @@ public:
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr (UseScaleLookupTable) {
|
||||
constexpr int num_elements = decltype(size(src))::value;
|
||||
static_assert(is_same_v<RealSwappedElementA, cutlass::int4b_t>, "Lookup table only supports int4 being the quant type now.");
|
||||
@ -856,7 +873,7 @@ public:
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
|
||||
Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
|
||||
Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, Int<NumValPerSrcReg>{}));
|
||||
|
||||
|
||||
if constexpr (is_same_v<DstType, ElementScale>) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
@ -885,6 +902,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Utilities for any additional inputs inside of the TMA load
|
||||
template <
|
||||
class Params,
|
||||
@ -897,39 +915,39 @@ public:
|
||||
cute::tuple<Ts...> const& load_inputs,
|
||||
TensorStorage& shared_tensors,
|
||||
uint2 const& cluster_local_block_id,
|
||||
int const m_coord,
|
||||
int const m_coord,
|
||||
int const l_coord) {
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
return cute::make_tuple();
|
||||
}
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor gS_mkl = get<2>(load_inputs);
|
||||
auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
|
||||
Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tSgS = block_tma_s.partition_S(gS);
|
||||
Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(tSgS, tSsS);
|
||||
}
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor gZ_mkl = get<3>(load_inputs);
|
||||
auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
|
||||
Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tZgZ = block_tma_z.partition_S(gZ);
|
||||
Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
|
||||
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -938,7 +956,7 @@ public:
|
||||
class ThreadMma,
|
||||
class TensorStorage
|
||||
>
|
||||
CUTLASS_DEVICE
|
||||
CUTLASS_DEVICE
|
||||
static auto partition_extra_mma_info(
|
||||
ThreadMma const& mma_thread_slice,
|
||||
TensorStorage& shared_tensors) {
|
||||
@ -950,8 +968,8 @@ public:
|
||||
else if constexpr (UseScaleLookupTable) {
|
||||
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
|
||||
Tensor tCsS = mma_thread_slice.partition_A(sS);
|
||||
Tensor tCrS_neg = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
|
||||
Tensor tCrS_pos = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
|
||||
Tensor tCrS_neg = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
|
||||
Tensor tCrS_pos = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos);
|
||||
@ -960,7 +978,7 @@ public:
|
||||
else if constexpr (ModeHasScales) {
|
||||
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
|
||||
Tensor tCsS = mma_thread_slice.partition_A(sS);
|
||||
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
|
||||
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(tCsS, tCrS);
|
||||
@ -968,13 +986,13 @@ public:
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
|
||||
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
|
||||
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout());
|
||||
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout());
|
||||
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
@ -996,18 +1014,18 @@ public:
|
||||
auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
|
||||
auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
|
||||
Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
|
||||
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
|
||||
}
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
|
||||
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
|
||||
@ -1519,6 +1519,105 @@ public:
|
||||
>::CollectiveOp;
|
||||
};
|
||||
|
||||
template <
|
||||
class MmaTileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
class ElementC_,
|
||||
class GmemLayoutTagC_,
|
||||
int AlignmentC,
|
||||
class ElementD,
|
||||
class GmemLayoutTagD,
|
||||
int AlignmentD,
|
||||
class EpilogueScheduleType,
|
||||
class FusionOp
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassSimt,
|
||||
MmaTileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC_,
|
||||
GmemLayoutTagC_,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
FusionOp,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<EpilogueScheduleType, EpilogueSimtVectorized> ||
|
||||
cute::is_same_v<EpilogueScheduleType, EpiloguePtrArraySimtVectorized> ||
|
||||
cute::is_same_v<EpilogueScheduleType, EpilogueScheduleAuto> >> {
|
||||
using CtaTileShape_MNK = MmaTileShape_MNK; // cluster MMA not supported
|
||||
|
||||
// Passing void C disables source load
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
||||
ElementD, ElementC_>; // prevents void ref breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
||||
GmemLayoutTagD, GmemLayoutTagC_>;
|
||||
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
|
||||
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
using ThreadOp = cute::conditional_t<
|
||||
IsDefaultFusionOp<FusionOp>::value,
|
||||
thread::LinearCombination<
|
||||
ElementD, AlignmentD, ElementAccumulator, ElementCompute,
|
||||
ScaleType, FloatRoundStyle::round_to_nearest, ElementC>
|
||||
,
|
||||
thread::LinearCombinationBiasElementwise<
|
||||
ElementC, ElementAccumulator, ElementCompute, ElementD, ElementD, AlignmentD,
|
||||
typename FusionOp::ActivationFn, cutlass::plus<ElementCompute>,
|
||||
false, typename FusionOp::ElementBias>
|
||||
>;
|
||||
static_assert(not (cute::is_same_v<EpilogueScheduleType, EpiloguePtrArraySimtVectorized> && not IsDefaultFusionOp<FusionOp>::value), "unsupported schedule + fusion");
|
||||
|
||||
using WarpShape_MNK = decltype(cutlass::gemm::collective::detail::sm100_simt_f32_warp_shape_mnk_selector<CtaTileShape_MNK>());
|
||||
static constexpr int ThreadCount = cute::size(WarpShape_MNK{}) * NumThreadsPerWarp;
|
||||
static constexpr int WarpShape_M = cute::size<0>(WarpShape_MNK{});
|
||||
static constexpr int WarpShape_N = cute::size<1>(WarpShape_MNK{});
|
||||
|
||||
// For 32 threads in 1 warp, we use [8 x 4] thread layouts and each thread will hold [4 x 4] accumulator value layouts.
|
||||
// Then totally each warp will hold [32 x 16] accumulator value layouts.
|
||||
// We separate the whole epilogue calculation to multi steps,
|
||||
// each step will calculate 1x [32 x 16] for each warp to reduce register pressure (mainly for C register allocation for beta 1!= 0 case).
|
||||
// So EpiTileM = WarpShape_M * 32 and EpiTileN = WarpShape_N * 16.
|
||||
using EpiTileM = Int<WarpShape_M * 32>;
|
||||
using EpiTileN = Int<WarpShape_N * 16>;
|
||||
|
||||
using SmemLayout = cute::conditional_t<cutlass::detail::is_major<0>(GmemStrideTypeD{}),
|
||||
cute::Layout<cute::Shape<EpiTileM, EpiTileN>, cute::Stride<_1, EpiTileM>>,
|
||||
cute::Layout<cute::Shape<EpiTileM, EpiTileN>, cute::Stride<EpiTileN, _1>>>;
|
||||
|
||||
using CopyAtomR2S = Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccumulator>;
|
||||
|
||||
using CopyAtomS2R = Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<AlignmentD * sizeof_bits_v<ElementAccumulator>>, ElementAccumulator>;
|
||||
|
||||
using TiledCopyS2R = decltype(
|
||||
cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy<
|
||||
CopyAtomS2R, ThreadCount, AlignmentD, GmemStrideTypeD, EpiTileM, EpiTileN>());
|
||||
|
||||
using Schedule = cute::conditional_t<is_same_v<EpilogueScheduleType, EpilogueScheduleAuto>,
|
||||
EpilogueSimtVectorized,
|
||||
EpilogueScheduleType>;
|
||||
using CopyAtomR2G = Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<AlignmentD * sizeof_bits_v<ElementD>>, ElementD>;
|
||||
using CollectiveOp = cutlass::epilogue::collective::Epilogue<
|
||||
GmemStrideTypeC,
|
||||
GmemStrideTypeD,
|
||||
ThreadOp,
|
||||
SmemLayout,
|
||||
CopyAtomR2S,
|
||||
TiledCopyS2R,
|
||||
CopyAtomR2G,
|
||||
Schedule>;
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::epilogue::collective
|
||||
|
||||
@ -205,6 +205,16 @@ struct IsThreadEpilogueOpWithPerChannelScaling <ThreadEpilogueOp, cute::enable_i
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename ThreadEpilogueOp, typename = void>
|
||||
struct IsThreadEpilogueOpWithResidualAdd {
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename ThreadEpilogueOp>
|
||||
struct IsThreadEpilogueOpWithResidualAdd <ThreadEpilogueOp, cute::void_t<decltype(ThreadEpilogueOp::IsResidualSupported)>> {
|
||||
static constexpr bool value = ThreadEpilogueOp::IsResidualSupported;
|
||||
};
|
||||
|
||||
template <typename ThreadEpilogueOp, typename = void>
|
||||
struct IsThreadEpilogueOpWithActivation {
|
||||
static constexpr bool value = false;
|
||||
|
||||
@ -39,6 +39,8 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/detail/helper_macros.hpp"
|
||||
#include "cutlass/conv/convnd_problem_shape.hpp"
|
||||
#include "cutlass/conv/detail.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
@ -133,6 +135,7 @@ public:
|
||||
constexpr static int ThreadCount = 128;
|
||||
constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount;
|
||||
constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value;
|
||||
constexpr static bool isSourceNeeded = not cute::is_void_v<ElementC>;
|
||||
|
||||
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
||||
constexpr static uint32_t TmaTransactionBytes = 0;
|
||||
@ -173,12 +176,27 @@ public:
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
template <conv::Operator ConvOp, int NumDims>
|
||||
static bool
|
||||
can_implement(cutlass::conv::ConvProblemShape<ConvOp,NumDims> const& problem_shape, Arguments const& args) {
|
||||
return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args);
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
[[maybe_unused]] ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
return true;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
auto shape = cute::make_shape(M,N,L);
|
||||
|
||||
bool implementable = true;
|
||||
implementable = implementable && cutlass::detail::check_alignment<AlignmentD{}>(shape, StrideD{});
|
||||
if constexpr (isSourceNeeded) {
|
||||
implementable = implementable && cutlass::detail::check_alignment<AlignmentC{}>(shape, StrideC{});
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@ -57,6 +57,7 @@ struct FusionOperation {
|
||||
|
||||
using ElementSource = void;
|
||||
static constexpr bool IsSourceSupported = false;
|
||||
static constexpr bool IsResidualSupported = false; // Source is added after activation
|
||||
|
||||
using ElementScalar = void;
|
||||
static constexpr int AlignmentScalar = 0;
|
||||
@ -317,6 +318,24 @@ struct PerColLinCombPerColBiasEltAct
|
||||
static constexpr bool IsPerColScaleSupported = true;
|
||||
};
|
||||
|
||||
// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C
|
||||
template<
|
||||
template <class> class ActivationFn_,
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_, // per-row alpha/beta
|
||||
int AlignmentBias_ = 128 / cute::sizeof_bits_v<ElementBias_>,
|
||||
int AlignmentScalar_ = 128 / cute::sizeof_bits_v<ElementScalar_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct PerColResAddPerColBiasEltAct
|
||||
: PerColLinCombPerColBiasEltAct<ActivationFn_, ElementOutput_, ElementCompute_,
|
||||
ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, AlignmentScalar_, RoundStyle_> {
|
||||
static constexpr bool IsResidualSupported = true;
|
||||
};
|
||||
|
||||
// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
|
||||
// if D is fp8
|
||||
// D = scale_d * activation(Z)
|
||||
|
||||
@ -1306,6 +1306,114 @@ struct FusionCallbacks<
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C
|
||||
template<
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile,
|
||||
template <class> class ActivationFn,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
int AlignmentScalar = 128 / sizeof_bits_v<ElementScalar>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90PerColResAddPerColBiasEltAct =
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + activation(alpha * acc + bias)
|
||||
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast
|
||||
Sm90SrcFetch<ElementSource>, // C
|
||||
Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(alpha * acc + bias)
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast
|
||||
Sm90AccFetch, // acc
|
||||
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias
|
||||
>
|
||||
>
|
||||
>;
|
||||
|
||||
template <
|
||||
int StagesC,
|
||||
int StagesD,
|
||||
int FragmentSize,
|
||||
bool ReuseSmemC,
|
||||
bool DelayTmaStore,
|
||||
template <class> class ActivationFn,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
int AlignmentScalar,
|
||||
FloatRoundStyle RoundStyle,
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile
|
||||
>
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>,
|
||||
fusion::PerColResAddPerColBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90PerColResAddPerColBiasEltAct<
|
||||
CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
> {
|
||||
|
||||
using Impl =
|
||||
Sm90PerColResAddPerColBiasEltAct<
|
||||
CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
>;
|
||||
using Operation =
|
||||
fusion::PerColResAddPerColBiasEltAct<
|
||||
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle
|
||||
>;
|
||||
|
||||
struct Arguments {
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
ElementScalar const* alpha_ptr = nullptr;
|
||||
ElementScalar const* beta_ptr = nullptr;
|
||||
|
||||
using StrideAlpha = Stride<_0,bool,int64_t>;
|
||||
using StrideBeta = Stride<_0,bool,int64_t>;
|
||||
StrideAlpha dAlpha = {_0{}, bool(1), 0};
|
||||
StrideBeta dBeta = {_0{}, bool(1), 0};
|
||||
|
||||
using StrideBias = Stride<_0,_1,int64_t>;
|
||||
ElementBias const* bias_ptr = nullptr;
|
||||
StrideBias dBias = {};
|
||||
|
||||
using ActivationArguments = typename Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>::Arguments;
|
||||
ActivationArguments activation = ActivationArguments();
|
||||
|
||||
operator typename Impl::Arguments() const {
|
||||
return
|
||||
{ // ternary op : beta * C + activation(alpha * acc + bias)
|
||||
{beta_ptr, beta, dBeta}, // leaf args : beta
|
||||
{}, // leaf args : C
|
||||
{ // unary op : activation(alpha * acc + bias)
|
||||
{ // ternary op : alpha * acc + bias
|
||||
{alpha_ptr, alpha, dAlpha}, // leaf args : alpha
|
||||
{}, // leaf args : acc
|
||||
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
|
||||
{} // ternary args : multiply_add
|
||||
}, // end ternary op
|
||||
activation // unary args : activation
|
||||
}, // end unary op
|
||||
{} // ternary args : multiply_add
|
||||
}; // end ternary op
|
||||
}
|
||||
};
|
||||
|
||||
// Ctor inheritance
|
||||
using Impl::Impl;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -591,7 +591,7 @@ struct Sm90TreeVisitor<
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator<cutlass::uint1b_t>(params_aux.ptr_aux));
|
||||
gmem_ptr ptr_aux = make_gmem_ptr<cutlass::uint1b_t>(params_aux.ptr_aux);
|
||||
Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L)
|
||||
Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
|
||||
|
||||
@ -765,7 +765,7 @@ struct Sm90AuxLoad<
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator<cutlass::uint1b_t const>(params.ptr_aux));
|
||||
gmem_ptr ptr_aux = make_gmem_ptr<cutlass::uint1b_t const>(params.ptr_aux);
|
||||
Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L)
|
||||
Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
|
||||
|
||||
|
||||
@ -1173,8 +1173,9 @@ public:
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
Layout ref_layout_MN = [&] () {
|
||||
if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); }
|
||||
else { return get<0>(args.tiled_copy.get_layoutD_MN()); }
|
||||
auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{});
|
||||
if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); }
|
||||
else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); }
|
||||
}(); // tile_mn -> tv_idx
|
||||
|
||||
// Get the MN layout + coord of lanes to determine shuffle reduction iterations
|
||||
@ -1650,8 +1651,9 @@ public:
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
Layout ref_layout_MN = [&] () {
|
||||
if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); }
|
||||
else { return get<0>(args.tiled_copy.get_layoutD_MN()); }
|
||||
auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{});
|
||||
if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); }
|
||||
else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); }
|
||||
}(); // tile_mn -> tv_idx
|
||||
|
||||
// Get the MN layout + coord of lanes to determine shuffle reduction iterations
|
||||
|
||||
@ -93,7 +93,7 @@ Array<float, 2> top_2_reduce(Array<float, 2> a, Array<float, 2> b) {
|
||||
" setp.gtu.f32 p, %2, %4;\n" // a0 > b0
|
||||
" selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1)
|
||||
" selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0
|
||||
"}\n" : "=f"(out[0]), "=f"(out[1]) :
|
||||
"}\n" : "=f"(out[0]), "=f"(out[1]) :
|
||||
"f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1]));
|
||||
return out;
|
||||
}
|
||||
@ -117,8 +117,8 @@ Array<float, 4> top_4_reduce_scalar(Array<float, 4> a, float scalar) {
|
||||
" selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b
|
||||
" selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0
|
||||
" selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b
|
||||
"}\n" :
|
||||
"=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
|
||||
"}\n" :
|
||||
"=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
|
||||
"f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar));
|
||||
return out;
|
||||
}
|
||||
@ -187,8 +187,8 @@ Array<float, 4> top_4_reduce(Array<float, 4> a, Array<float, 4> b) {
|
||||
" selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case
|
||||
" selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0
|
||||
" selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0
|
||||
"}\n" :
|
||||
"=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
|
||||
"}\n" :
|
||||
"=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
|
||||
"f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]),
|
||||
"f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3]));
|
||||
return out;
|
||||
@ -351,7 +351,7 @@ private:
|
||||
// we can track logsumexp instead of tracking two variables (sum of exps and the max).
|
||||
// In addition, subtracting logsumexp from any element and taking its exp is equivalent to
|
||||
// computing its softmax.
|
||||
//
|
||||
//
|
||||
// The overlap between softmax and top-K is that we don't need to reduce logsumexp along the
|
||||
// way at all, because any element not in the top-K is going to be masked out and set to 0.
|
||||
// Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and
|
||||
@ -370,7 +370,7 @@ private:
|
||||
ReductionResult() { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ReductionResult(ElementCompute min, ElementCompute logsumexp):
|
||||
ReductionResult(ElementCompute min, ElementCompute logsumexp):
|
||||
logsumexp_(logsumexp), min_(min) { }
|
||||
|
||||
// Warp shuffle broadcast
|
||||
@ -541,7 +541,7 @@ public:
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
|
||||
Array<ElementInput, FragmentSize> const& frg_input) {
|
||||
|
||||
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
|
||||
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
|
||||
lane_layout_MN, lane_mn,
|
||||
residue_cCol, residue_tCcCol] = args_tuple;
|
||||
Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n);
|
||||
@ -566,7 +566,7 @@ public:
|
||||
CUTLASS_DEVICE void
|
||||
reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) {
|
||||
|
||||
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
|
||||
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
|
||||
lane_layout_MN, lane_mn,
|
||||
residue_cCol, residue_tCcCol] = args_tuple;
|
||||
|
||||
@ -668,7 +668,7 @@ public:
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
end_loop(int epi_m, int epi_n) {
|
||||
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
|
||||
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
|
||||
lane_layout_MN, lane_mn,
|
||||
residue_cCol, residue_tCcCol] = args_tuple;
|
||||
|
||||
@ -690,8 +690,9 @@ public:
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
Layout ref_layout_MN = [&] () {
|
||||
if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); }
|
||||
else { return get<0>(args.tiled_copy.get_layoutD_MN()); }
|
||||
auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{});
|
||||
if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); }
|
||||
else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); }
|
||||
}(); // tile_mn -> tv_idx
|
||||
|
||||
// Get the MN layout + coord of lanes to determine shuffle reduction iterations
|
||||
@ -739,7 +740,7 @@ public:
|
||||
Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N)
|
||||
|
||||
// Compose the new accumulator R2S layout with the expected tCrC layout to get final
|
||||
// Compose the new accumulator R2S layout with the expected tCrC layout to get final
|
||||
// reduction tensor layout.
|
||||
auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N)
|
||||
|
||||
|
||||
@ -569,6 +569,47 @@ sm100_make_trivial_fastFP32_tiled_mma() {
|
||||
}
|
||||
}
|
||||
|
||||
template<
|
||||
class CtaShape_MNK
|
||||
>
|
||||
constexpr auto
|
||||
sm100_simt_f32_warp_shape_mnk_selector() {
|
||||
using namespace cute;
|
||||
|
||||
constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{});
|
||||
constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{});
|
||||
constexpr int CtaShape_K = cute::size<2>(CtaShape_MNK{});
|
||||
|
||||
// CTA tile shape M and N are supposed to be divisible by 32.
|
||||
static_assert(CtaShape_M % 32 == 0, "CtaShape_M needs to be divisible by 32.");
|
||||
static_assert(CtaShape_N % 32 == 0, "CtaShape_N needs to be divisible by 32.");
|
||||
|
||||
// WarpShape_MNK configuration
|
||||
// We assume WarpShape_K is always 1 in our SM100 SIMT SGEMM implementation.
|
||||
if constexpr (CtaShape_M >= CtaShape_N) {
|
||||
if constexpr (CtaShape_M == 256 && CtaShape_N == 128) {
|
||||
return cute::Shape<_4, _2, _1>{};
|
||||
}
|
||||
else if constexpr ((CtaShape_M == 64 || CtaShape_M == 32) && CtaShape_N == 32) {
|
||||
return cute::Shape<_1, _2, _1>{};
|
||||
}
|
||||
else {
|
||||
return cute::Shape<_2, _2, _1>{};
|
||||
}
|
||||
}
|
||||
else {
|
||||
if constexpr (CtaShape_M == 128 && CtaShape_N == 256) {
|
||||
return cute::Shape<_2, _4, _1>{};
|
||||
}
|
||||
else if constexpr (CtaShape_M == 32 && CtaShape_N == 64) {
|
||||
return cute::Shape<_1, _2, _1>{};
|
||||
}
|
||||
else {
|
||||
return cute::Shape<_1, _4, _1>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
class ElementPairA,
|
||||
|
||||
216
include/cutlass/gemm/collective/builders/sm100_simt_builder.inl
Normal file
216
include/cutlass/gemm/collective/builders/sm100_simt_builder.inl
Normal file
@ -0,0 +1,216 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm100_common.inl"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<
|
||||
class LayoutA,
|
||||
int AlignmentA,
|
||||
class LayoutB,
|
||||
int AlignmentB,
|
||||
class CtaShape_MNK,
|
||||
class WarpShape_MNK
|
||||
>
|
||||
constexpr auto
|
||||
sm100_make_simt_f32_tiled_mma() {
|
||||
using namespace cute;
|
||||
|
||||
constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{});
|
||||
constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{});
|
||||
constexpr int CtaShape_K = cute::size<2>(CtaShape_MNK{});
|
||||
|
||||
constexpr int WarpShape_M = cute::size<0>(WarpShape_MNK{});
|
||||
constexpr int WarpShape_N = cute::size<1>(WarpShape_MNK{});
|
||||
constexpr int WarpShape_K = cute::size<2>(WarpShape_MNK{});
|
||||
|
||||
// Use Permutation to achieve a [4 x 4] value layout for each thread.
|
||||
// Ideally, we want the tiled mma to be such that loads from shared memory are 128 bit wide.
|
||||
// While as we are using CtaShape_K = 16, when A and B are K-major, we use tranpose + 8 byte padding to avoid smem bank conflict,
|
||||
// so we could only use 64 bit smem load.
|
||||
// When A and B are MN-major, we use 128 bit smem load.
|
||||
using PermutationA = Layout<Shape<_2, Int<WarpShape_M * 8>, _2>, Stride< _1, _4, _2>>;
|
||||
using PermutationB = Layout<Shape<Int<WarpShape_N * 4>, _4>, Stride< _4, _1>>;
|
||||
|
||||
// For 32 threads in 1 warp, we use [8 x 4] thread layouts and each thread will hold [4 x 4] value layouts.
|
||||
// Then totally each warp will hold [32 x 16] value layouts.
|
||||
// So WarpShape_M needs to be equal or smaller than CtaShape_M / 32 and WarpShape_N needs to be equal or smaller than CtaShape_N / 16.
|
||||
static_assert(WarpShape_M <= CtaShape_M / 32, "WarpShape_M is too large, it needs to be equal or smaller than CtaShape_M / 32.");
|
||||
static_assert(WarpShape_N <= CtaShape_N / 16, "WarpShape_N is too large, it needs to be equal or smaller than CtaShape_N / 16.");
|
||||
|
||||
constexpr int WarpStride_M = (WarpShape_M != 1) * NumThreadsPerWarp;
|
||||
constexpr int WarpStride_N = WarpShape_M * NumThreadsPerWarp;
|
||||
|
||||
// We first introduce a [8 x 4] thread layouts in 1 warp.
|
||||
// And inside this [8 x 4] thread layouts, each 4 threads will be arranged as [2 x 2].
|
||||
// Then we could set different WarpShape to finalize how many warps we use in our tiled mma.
|
||||
// For example :
|
||||
// With 128 threads in the tiled mma, we could set the WarpShapeMNK as [2 x 2 x 1], [1 x 4 x 1] and [4 x 1 x 1].
|
||||
// With 64 threads in the tiled mma, we could set the WarpShapeMNK as [1 x 2 x 1] and [2 x 1 x 1].
|
||||
return make_tiled_mma(
|
||||
MMA_Atom<SM100_2x1x1_F32F32F32F32>{},
|
||||
Layout<Shape < Shape <_2, _4, Int<WarpShape_M>>, Shape <_2, _2, Int<WarpShape_N>>, _1>,
|
||||
Stride< Stride<_1, _8, Int<WarpStride_M>>, Stride<_2, _4, Int<WarpStride_N>>, _1>>{},
|
||||
Tile<
|
||||
PermutationA,
|
||||
PermutationB,
|
||||
Underscore>{});
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class CtaShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
int stages,
|
||||
class BuilderScheduleTag>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassSimt,
|
||||
float,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
float,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
float,
|
||||
CtaShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCount<stages>,
|
||||
BuilderScheduleTag,
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<BuilderScheduleTag, KernelMultistage> ||
|
||||
cute::is_same_v<BuilderScheduleTag, KernelPtrArrayMultistage> ||
|
||||
cute::is_same_v<BuilderScheduleTag, KernelScheduleAuto>) &&
|
||||
((sizeof(float) * AlignmentA) % detail::cp_async_min_alignment_bytes == 0) &&
|
||||
((sizeof(float) * AlignmentB) % detail::cp_async_min_alignment_bytes == 0) >> {
|
||||
static_assert(cute::size<2>(CtaShape_MNK{}) == 16, "SM100 SIMT SGEMM Kernels only support TileShape_K = 16.");
|
||||
|
||||
// This kernel is specialized for F32 data type.
|
||||
using ElementA = float;
|
||||
using ElementB = float;
|
||||
|
||||
using M = decltype(cute::size<0>(CtaShape_MNK{}));
|
||||
using N = decltype(cute::size<1>(CtaShape_MNK{}));
|
||||
using K = decltype(cute::size<2>(CtaShape_MNK{}));
|
||||
|
||||
using WarpShape_MNK = decltype(detail::sm100_simt_f32_warp_shape_mnk_selector<CtaShape_MNK>());
|
||||
|
||||
static constexpr int ThreadCount = cute::size(WarpShape_MNK{}) * NumThreadsPerWarp;
|
||||
|
||||
using TiledMma = decltype(
|
||||
detail::sm100_make_simt_f32_tiled_mma<
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
CtaShape_MNK,
|
||||
WarpShape_MNK>());
|
||||
|
||||
// for K major layouts, add a smem alignment offset to avoid bank conflicts
|
||||
static constexpr int SmemAlignmentOffsetA = cutlass::gemm::detail::is_mn_major_A<GmemLayoutATag>() ? 0 : 2;
|
||||
static constexpr int SmemAlignmentOffsetB = cutlass::gemm::detail::is_mn_major_B<GmemLayoutBTag>() ? 0 : 2;
|
||||
static constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{});
|
||||
static constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{});
|
||||
|
||||
// Shared memory layout is [M x K] in M-major
|
||||
using SmemLayoutAtomA = cute::Layout<cute::Shape< M, K>,
|
||||
cute::Stride<_1, Int<CtaShape_M + SmemAlignmentOffsetA>>>;
|
||||
// A M-major use 128bit smem load.
|
||||
// A K-major needs to do tranpose and 8 byte padding to make smem bank conflict free, then we can only use 64bit smem load.
|
||||
using SmemCopyAtomA = std::conditional_t<cutlass::gemm::detail::is_mn_major_A<GmemLayoutATag>(),
|
||||
cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>,
|
||||
cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<64>, ElementA>>;
|
||||
|
||||
using AlignmentTypeA = cute::uint_byte_t<static_cast<int>(sizeof(ElementA)) * AlignmentA>;
|
||||
using GmemCopyAtomA = cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS_ZFILL<AlignmentTypeA>, ElementA>;
|
||||
using GmemTiledCopyA = decltype(
|
||||
detail::make_simt_gmem_tiled_copy<
|
||||
GmemCopyAtomA, ThreadCount, AlignmentA, TagToStrideA_t<GmemLayoutATag>, M, K>());
|
||||
|
||||
// Shared memory layout is [N x K] in N-major
|
||||
using SmemLayoutAtomB = cute::Layout<cute::Shape< N, K>,
|
||||
cute::Stride<_1, Int<CtaShape_N + SmemAlignmentOffsetB>>>;
|
||||
// B N-major use 128bit smem load.
|
||||
// B K-major needs to do tranpose and 8 byte padding to make smem bank conflict free, then we can only use 64bit smem load.
|
||||
using SmemCopyAtomB = std::conditional_t<cutlass::gemm::detail::is_mn_major_B<GmemLayoutBTag>(),
|
||||
cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<128>, ElementB>,
|
||||
cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<64>, ElementB>>;
|
||||
|
||||
using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(sizeof(ElementB)) * AlignmentB>;
|
||||
using GmemCopyAtomB = cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS_ZFILL<AlignmentTypeB>, ElementB>;
|
||||
using GmemTiledCopyB = decltype(
|
||||
detail::make_simt_gmem_tiled_copy<
|
||||
GmemCopyAtomB, ThreadCount, AlignmentB, TagToStrideB_t<GmemLayoutBTag>, N, K>());
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = cute::is_same_v<BuilderScheduleTag, KernelPtrArrayMultistage>;
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
cutlass::gemm::MainloopSm80ArrayCpAsync<stages,
|
||||
ClusterShape_MNK>,
|
||||
cutlass::gemm::MainloopSm80CpAsync<stages,
|
||||
ClusterShape_MNK>
|
||||
>;
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
DispatchPolicy,
|
||||
CtaShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -46,6 +46,7 @@
|
||||
#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_simt_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm120_mma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl"
|
||||
|
||||
@ -37,6 +37,7 @@
|
||||
|
||||
#include "cutlass/gemm/collective/sm70_mma_twostage.hpp"
|
||||
#include "cutlass/gemm/collective/sm80_mma_multistage.hpp"
|
||||
#include "cutlass/gemm/collective/sm80_mma_array_multistage.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp"
|
||||
|
||||
412
include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp
Normal file
412
include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp
Normal file
@ -0,0 +1,412 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape_,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_
|
||||
>
|
||||
struct CollectiveMma<
|
||||
MainloopSm80ArrayCpAsync<
|
||||
Stages,
|
||||
ClusterShape_>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_
|
||||
>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm80ArrayCpAsync<
|
||||
Stages,
|
||||
ClusterShape_>;
|
||||
using TileShape = TileShape_;
|
||||
// Follow the change in TestSmall: TileShape switch to CtaShape
|
||||
// In legacy arch, it should be same
|
||||
using CtaShape_MNK = TileShape;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using InternalStrideA = cute::remove_pointer_t<StrideA>;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using InternalStrideB = cute::remove_pointer_t<StrideB>;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using ArrayElementA = ElementA;
|
||||
using ArrayElementB = ElementB;
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{})));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{})));
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline.");
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
cute::array_aligned<ElementA, cute::cosize_v<SmemLayoutA>> smem_a;
|
||||
cute::array_aligned<ElementB, cute::cosize_v<SmemLayoutB>> smem_b;
|
||||
};
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const** ptr_A{nullptr};
|
||||
StrideA dA{};
|
||||
ElementB const** ptr_B{nullptr};
|
||||
StrideB dB{};
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
using Params = Arguments;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CollectiveMma() = default;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
return args;
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
template <
|
||||
class FrgTensorD,
|
||||
class TensorA,
|
||||
class TensorB,
|
||||
class FrgTensorC,
|
||||
class KTileIterator,
|
||||
class ResidueMNK
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
operator() (
|
||||
FrgTensorD &accum,
|
||||
TensorA gA, // (BLK_M, BLK_K, K_TILES)
|
||||
TensorB gB, // (BLK_N, BLK_K, K_TILES)
|
||||
FrgTensorC const &src_accum,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
ResidueMNK residue_mnk,
|
||||
int thread_idx,
|
||||
char *smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(is_rmem<FrgTensorD>::value, "D tensor must be rmem resident.");
|
||||
static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
|
||||
static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
|
||||
// Construct shared memory tiles
|
||||
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N
|
||||
CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
|
||||
// Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k)
|
||||
// This aligns the tensor with BLK_K for all but the 0th k_tile
|
||||
gA = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA);
|
||||
gB = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB);
|
||||
|
||||
// Partition the copying of A and B tiles across the threads
|
||||
GmemTiledCopyA gmem_tiled_copy_A;
|
||||
GmemTiledCopyB gmem_tiled_copy_B;
|
||||
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
|
||||
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
|
||||
|
||||
Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
|
||||
Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE)
|
||||
Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE)
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
|
||||
// Allocate predicate tensors for m and n
|
||||
Tensor tApA = make_tensor<bool>(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{});
|
||||
Tensor tBpB = make_tensor<bool>(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{});
|
||||
|
||||
// Construct identity layout for sA and sB
|
||||
Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||
|
||||
// Set predicates for m bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<0>(tApA); ++m) {
|
||||
tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m
|
||||
}
|
||||
// Set predicates for n bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<0>(tBpB); ++n) {
|
||||
tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n
|
||||
}
|
||||
|
||||
//
|
||||
// PREFETCH
|
||||
//
|
||||
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
clear(tAsA);
|
||||
clear(tBsB);
|
||||
|
||||
// Start async loads for 0th k-tile, where we take care of the k residue
|
||||
{
|
||||
constexpr int k_pipe = 0;
|
||||
|
||||
Tensor tAgAk = tAgA(_,_,_,*k_tile_iter);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < size<2>(tAsA); ++k) {
|
||||
if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted)
|
||||
copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe));
|
||||
}
|
||||
}
|
||||
Tensor tBgBk = tBgB(_,_,_,*k_tile_iter);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < size<2>(tBsB); ++k) {
|
||||
if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted)
|
||||
copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe));
|
||||
}
|
||||
}
|
||||
cp_async_fence();
|
||||
++k_tile_iter;
|
||||
--k_tile_count;
|
||||
}
|
||||
|
||||
// Start async loads for 1st k-tile onwards, no k-residue handling needed
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) {
|
||||
if (k_tile_count <= 0) {
|
||||
clear(tApA);
|
||||
clear(tBpB);
|
||||
}
|
||||
copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync
|
||||
copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync
|
||||
cp_async_fence();
|
||||
++k_tile_iter;
|
||||
--k_tile_count;
|
||||
}
|
||||
|
||||
//
|
||||
// MMA Atom partitioning
|
||||
//
|
||||
|
||||
// Tile MMA compute thread partitions and allocate accumulators
|
||||
TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
|
||||
//
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE)
|
||||
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K
|
||||
|
||||
auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE)
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
|
||||
// Current pipe index in smem to read from
|
||||
int smem_pipe_read = 0;
|
||||
// Current pipe index in smem to write to
|
||||
int smem_pipe_write = DispatchPolicy::Stages-1;
|
||||
|
||||
Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read);
|
||||
Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read);
|
||||
|
||||
// Size of the register pipeline
|
||||
auto K_BLOCK_MAX = size<2>(tCrA);
|
||||
|
||||
// PREFETCH register pipeline
|
||||
if (K_BLOCK_MAX > 1) {
|
||||
// Wait until our first prefetched tile is loaded in
|
||||
cp_async_wait<DispatchPolicy::Stages-2>();
|
||||
__syncthreads();
|
||||
|
||||
// Prefetch the first rmem from the first k-tile
|
||||
copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
|
||||
copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count)
|
||||
{
|
||||
// Pipeline the outer products with a static for loop.
|
||||
//
|
||||
// Note, the for_each() function is required here to ensure `k_block` is of type Int<N>.
|
||||
for_each(make_int_sequence<K_BLOCK_MAX>{}, [&] (auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
// Slice the smem_pipe_read smem
|
||||
tCsA_p = tCsA(_,_,_,smem_pipe_read);
|
||||
tCsB_p = tCsB(_,_,_,smem_pipe_read);
|
||||
|
||||
// Commit the smem for smem_pipe_read
|
||||
cp_async_wait<DispatchPolicy::Stages-2>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
||||
copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
||||
copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
||||
// Copy gmem to smem before computing gemm on each k-pipe
|
||||
if (k_block == 0)
|
||||
{
|
||||
// Set all predicates to false if we are going to overshoot bounds
|
||||
if (k_tile_count <= 0) {
|
||||
clear(tApA);
|
||||
clear(tBpB);
|
||||
}
|
||||
copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write));
|
||||
copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write));
|
||||
cp_async_fence();
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
|
||||
// Transform before compute
|
||||
cute::transform(tCrA(_,_,k_block), TransformA{});
|
||||
cute::transform(tCrB(_,_,k_block), TransformB{});
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -89,15 +89,10 @@ struct CollectiveMma<
|
||||
TransformB_>
|
||||
{
|
||||
public:
|
||||
enum class ConversionMode {
|
||||
DirectConvert,
|
||||
ConvertAndScale,
|
||||
ConvertAndScaleWithZero
|
||||
};
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ConversionMode = cutlass::detail::ConversionMode;
|
||||
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<Stages, ClusterShape, KernelSchedule_>;
|
||||
using TileShape = TileShape_;
|
||||
using KernelSchedule = KernelSchedule_;
|
||||
|
||||
@ -96,15 +96,11 @@ struct CollectiveMma<
|
||||
TransformB_>
|
||||
{
|
||||
public:
|
||||
enum class ConversionMode {
|
||||
DirectConvert,
|
||||
ConvertAndScale,
|
||||
ConvertAndScaleWithZero
|
||||
};
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ConversionMode = cutlass::detail::ConversionMode;
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<Stages, ClusterShape, KernelSchedule_>;
|
||||
using TileShape = TileShape_;
|
||||
using KernelSchedule = KernelSchedule_;
|
||||
|
||||
@ -109,6 +109,7 @@ static constexpr bool HasAuxiliaryLoad_v = HasAuxiliaryLoad<T>::value;
|
||||
// Kernel schedule policies (the base class tags, one for each kernel layer file)
|
||||
//
|
||||
struct KernelMultistage { };
|
||||
struct KernelPtrArrayMultistage { };
|
||||
struct KernelCpAsyncWarpSpecialized { };
|
||||
struct KernelCpAsyncWarpSpecializedPingpong { };
|
||||
struct KernelCpAsyncWarpSpecializedCooperative { };
|
||||
@ -198,6 +199,17 @@ struct MainloopSm80CpAsync {
|
||||
using ClusterShape = ClusterShape_;
|
||||
};
|
||||
|
||||
// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads for SM100 Simt Ptr-Array
|
||||
template<int Stages_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>
|
||||
>
|
||||
struct MainloopSm80ArrayCpAsync {
|
||||
constexpr static int Stages = Stages_;
|
||||
using ArchTag = cute::conditional_t<(size(ClusterShape_{}) > 1), arch::Sm90, arch::Sm80>;
|
||||
using Schedule = KernelPtrArrayMultistage;
|
||||
using ClusterShape = ClusterShape_;
|
||||
};
|
||||
|
||||
// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule
|
||||
template<
|
||||
int Stages_,
|
||||
@ -479,6 +491,16 @@ struct KernelTmaWarpSpecializedInputTransformSm100 final {
|
||||
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
|
||||
};
|
||||
|
||||
// InputTransform GEMM
|
||||
template<
|
||||
int SchedulerPipelineStageCount_,
|
||||
int AccumulatorPipelineStageCount_
|
||||
>
|
||||
struct KernelTmaWarpSpecializedMixedInputTransformSm100 final {
|
||||
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
||||
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
|
||||
};
|
||||
|
||||
// Ptr-Array Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms
|
||||
template<
|
||||
int SchedulerPipelineStageCount_,
|
||||
|
||||
@ -54,6 +54,7 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/gemm/kernel/sm70_gemm.hpp"
|
||||
#include "cutlass/gemm/kernel/sm70_gemm_array.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_gemm_tma.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp"
|
||||
|
||||
@ -1008,39 +1008,6 @@ public:
|
||||
// Advance the mm2accum pipe
|
||||
mma2accum_pipeline_consumer_state = mma2accum_pipeline_consumer_state_next;
|
||||
}
|
||||
else if constexpr (InputTransformType == cutlass::gemm::detail::KernelInputTransformType::MixedInput) {
|
||||
|
||||
mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state);
|
||||
|
||||
// Accumulators
|
||||
Tensor accumulators = bulk_tmem(_,_,_,mma2accum_pipeline_consumer_state.index()); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)
|
||||
|
||||
mma2accum_pipeline_consumer_state = scheduler.template fixup<IsComplex>(
|
||||
TiledMma{},
|
||||
work_tile_info,
|
||||
accumulators,
|
||||
mma2accum_pipeline,
|
||||
mma2accum_pipeline_consumer_state,
|
||||
typename CollectiveEpilogue::CopyOpT2R{}
|
||||
);
|
||||
|
||||
//
|
||||
// Epilogue and write to gD
|
||||
//
|
||||
if (scheduler.compute_epilogue(work_tile_info)) {
|
||||
auto [mma2accum_pipeline_state_next] = collective_epilogue(
|
||||
mma2accum_pipeline,
|
||||
mma2accum_pipeline_consumer_state,
|
||||
problem_shape_MNKL,
|
||||
CtaShape_MNK{},
|
||||
cta_coord_mnkl,
|
||||
accumulators,
|
||||
shared_storage.tensors.epilogue
|
||||
);
|
||||
// Advance the mma2accum pipe
|
||||
mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next;
|
||||
}
|
||||
}
|
||||
// Complex kernels use a collective epilogue
|
||||
else {
|
||||
mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state);
|
||||
|
||||
279
include/cutlass/gemm/kernel/sm70_gemm_array.hpp
Normal file
279
include/cutlass/gemm/kernel/sm70_gemm_array.hpp
Normal file
@ -0,0 +1,279 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
class ProblemShape_,
|
||||
class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_,
|
||||
class TileScheduler_
|
||||
>
|
||||
class GemmUniversal<
|
||||
ProblemShape_,
|
||||
CollectiveMainloop_,
|
||||
CollectiveEpilogue_,
|
||||
TileScheduler_,
|
||||
cute::enable_if_t<cute::is_base_of_v<KernelPtrArrayMultistage, typename CollectiveMainloop_::DispatchPolicy::Schedule>>>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using InternalStrideA = typename CollectiveMainloop::InternalStrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using InternalStrideB = typename CollectiveMainloop::InternalStrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler = typename detail::TileSchedulerSelector<
|
||||
TileScheduler_, ArchTag, TileShape,
|
||||
cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
static constexpr bool IsGdcEnabled = false;
|
||||
|
||||
static constexpr bool is_valid_tile_scheduler =
|
||||
cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>;
|
||||
static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler.");
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using InternalStrideC = typename CollectiveEpilogue::InternalStrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using InternalStrideD = typename CollectiveEpilogue::InternalStrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
static_assert(cute::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
|
||||
"Mainloop and epilogue do not agree on accumulator value type.");
|
||||
|
||||
// MSVC requires the cast to fix a warning-as-error.
|
||||
static constexpr int SharedStorageSize = static_cast<int>(cute::max(
|
||||
sizeof(typename CollectiveMainloop::SharedStorage),
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage)));
|
||||
|
||||
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{}));
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
typename ProblemShape::UnderlyingProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static
|
||||
Params
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape();
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count};
|
||||
auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{});
|
||||
|
||||
return {
|
||||
args.mode,
|
||||
problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(problem_shape, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(problem_shape, args.epilogue, workspace)
|
||||
};
|
||||
}
|
||||
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
|
||||
bool implementable = (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape();
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_size = 0;
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static
|
||||
cutlass::Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
cutlass::Status status = Status::kSuccess;
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
int batch_count = cute::size<3>(params.problem_shape);
|
||||
return dim3(
|
||||
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
|
||||
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
|
||||
batch_count
|
||||
);
|
||||
}
|
||||
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
operator()(Params const& params, char* smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Preconditions
|
||||
CUTE_STATIC_ASSERT(is_static<TileShape>::value);
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// Preconditions
|
||||
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
int thread_idx = int(threadIdx.x);
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
|
||||
auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l)
|
||||
|
||||
// Represent the full tensors
|
||||
Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A[l_coord]), make_shape(M,K,1), params.mainloop.dA); //(m,k,l)
|
||||
Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B[l_coord]), make_shape(N,K,1), params.mainloop.dB); //(n,k,l)
|
||||
|
||||
// Get batch slice
|
||||
Tensor mA_mk = mA_mkl(_,_,0); // (m,k)
|
||||
Tensor mB_nk = mB_nkl(_,_,0); // (n,k)
|
||||
|
||||
// Slice to get the tiles this thread block is responsible for
|
||||
Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Compute tile residues for predication
|
||||
auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord
|
||||
auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord
|
||||
auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
|
||||
auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
|
||||
|
||||
// Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
|
||||
TiledMma tiled_mma;
|
||||
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
clear(accumulators);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
|
||||
int k_tile_count = size<2>(gA);
|
||||
|
||||
|
||||
// Perform the collective scoped MMA
|
||||
CollectiveMainloop collective_mma;
|
||||
collective_mma(
|
||||
accumulators,
|
||||
gA,
|
||||
gB,
|
||||
accumulators,
|
||||
k_tile_iter, k_tile_count,
|
||||
residue_mnk,
|
||||
thread_idx,
|
||||
smem_buf
|
||||
);
|
||||
|
||||
// Epilogue and write to gD
|
||||
CollectiveEpilogue epilogue{params.epilogue};
|
||||
epilogue(
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord_mnkl,
|
||||
accumulators,
|
||||
tiled_mma,
|
||||
residue_mnk,
|
||||
thread_idx,
|
||||
smem_buf
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@ -194,6 +194,9 @@ struct integer_subbyte {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// 1-bit binary type
|
||||
using bin1_t = bool;
|
||||
|
||||
/// 1-bit Unsigned integer type
|
||||
using uint1b_t = integer_subbyte<1, false>;
|
||||
|
||||
@ -209,14 +212,12 @@ using int4b_t = integer_subbyte<4, true>;
|
||||
/// 4-bit Unsigned integer type
|
||||
using uint4b_t = integer_subbyte<4, false>;
|
||||
|
||||
/// 6-bit integer type
|
||||
using int6b_t = integer_subbyte<6, true>;
|
||||
|
||||
/// 6-bit unsigned integer type
|
||||
using uint6b_t = integer_subbyte<6, false>;
|
||||
|
||||
|
||||
/// 1-bit binary type
|
||||
using bin1_t = bool;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, bool Signed>
|
||||
|
||||
@ -50,7 +50,13 @@ struct sizeof_bits {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct sizeof_bits<T const>: sizeof_bits<T> {};
|
||||
struct sizeof_bits<T const> : sizeof_bits<T> {};
|
||||
|
||||
template <typename T>
|
||||
struct sizeof_bits<T volatile> : sizeof_bits<T> {};
|
||||
|
||||
template <typename T>
|
||||
struct sizeof_bits<T const volatile> : sizeof_bits<T> {};
|
||||
|
||||
template <>
|
||||
struct sizeof_bits<void> {
|
||||
|
||||
Reference in New Issue
Block a user