cutlass 3.9 update (#2255)
* cutlass 3.9 update * rebase * fixes out of shared memory for blockwise Blackwell * doc format * fix issue 2253 * disable host ref by default * fix sm120 smem capacity --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@ -647,7 +647,7 @@ make_tma_atom_im2col(CopyOp,
|
||||
gtensor_cwhdn,
|
||||
range_c,
|
||||
range_whdn,
|
||||
detail::get_swizzle_portion(slayout),
|
||||
get_swizzle_portion(slayout),
|
||||
tma_layout_vt,
|
||||
lower_corner_whd,
|
||||
upper_corner_whd,
|
||||
|
||||
@ -37,10 +37,13 @@
|
||||
#include <cute/arch/mma_sm100.hpp>
|
||||
#include <cute/arch/mma_sm100_desc.hpp>
|
||||
#include <cute/arch/mma_sm100_umma.hpp>
|
||||
#include <cute/atom/copy_traits_sm100.hpp> // cute::TMEM::
|
||||
#include <cute/arch/tmem_allocator_sm100.hpp> // cute::TMEM::
|
||||
|
||||
#include <cute/atom/mma_traits.hpp>
|
||||
#include <cute/atom/mma_traits_sm90_gmma.hpp> // cute::GMMA::
|
||||
#include <cute/atom/mma_traits_sm90_gmma_sparse.hpp> // cute::GMMA::
|
||||
#include <cute/atom/copy_traits_sm100.hpp> // UTCCP smem desc
|
||||
|
||||
#include <cute/numeric/numeric_types.hpp>
|
||||
|
||||
// Check that aggregate initialization in .with() initializes all fields
|
||||
@ -417,6 +420,9 @@ constexpr auto get_utccp_smem_desc_tensor(Tensor<TEngine, TLayout> const& smem_u
|
||||
|
||||
namespace UMMA {
|
||||
|
||||
// Import TMEM constants
|
||||
namespace TMEM = cute::TMEM;
|
||||
|
||||
enum class TmemAllocMode {
|
||||
// Default allocation mode.
|
||||
// If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be
|
||||
@ -3053,7 +3059,7 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
|
||||
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types");
|
||||
// The accepted M extents for this 2x1SM atom are 128 and 256; the diagnostic text must state the same values as the condition.
static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
|
||||
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256.");
|
||||
|
||||
|
||||
using FrgTypeA = UMMA::smem_desc<a_major>;
|
||||
using FrgTypeB = UMMA::smem_desc<b_major>;
|
||||
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
|
||||
|
||||
Reference in New Issue
Block a user