cutlass 3.9 update (#2255)

* cutlass 3.9 update

* rebase

* fixes out of shared memory for blockwise Blackwell

* doc format

* fix issue 2253

* disable host ref by default

* fix sm120 smem capacity

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-04-24 12:42:40 -07:00
committed by GitHub
parent 8e345c5c5b
commit 331a1f5b3f
143 changed files with 18089 additions and 5935 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -647,7 +647,7 @@ make_tma_atom_im2col(CopyOp,
gtensor_cwhdn,
range_c,
range_whdn,
detail::get_swizzle_portion(slayout),
get_swizzle_portion(slayout),
tma_layout_vt,
lower_corner_whd,
upper_corner_whd,

View File

@@ -37,10 +37,13 @@
#include <cute/arch/mma_sm100.hpp>
#include <cute/arch/mma_sm100_desc.hpp>
#include <cute/arch/mma_sm100_umma.hpp>
#include <cute/atom/copy_traits_sm100.hpp> // cute::TMEM::
#include <cute/arch/tmem_allocator_sm100.hpp> // cute::TMEM::
#include <cute/atom/mma_traits.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp> // cute::GMMA::
#include <cute/atom/mma_traits_sm90_gmma_sparse.hpp> // cute::GMMA::
#include <cute/atom/copy_traits_sm100.hpp> // UTCCP smem desc
#include <cute/numeric/numeric_types.hpp>
// Check that aggregate initialization in .with() initializes all fields
@@ -417,6 +420,9 @@ constexpr auto get_utccp_smem_desc_tensor(Tensor<TEngine, TLayout> const& smem_u
namespace UMMA {
// Import TMEM constants
namespace TMEM = cute::TMEM;
enum class TmemAllocMode {
// Default allocation mode.
// If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be
@@ -3053,7 +3059,7 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types");
static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256.");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;