CUTLASS 3.8 Release (#2059)

* CUTLASS 3.8 Release

* update

* Update README.md

* Revert "Update README.md"

This reverts commit b353e36fe8.

* update

* update

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
mihir-awatramani
2025-01-24 23:44:06 -08:00
committed by GitHub
parent 9eb01fa0b0
commit 389e493055
290 changed files with 91223 additions and 292 deletions

View File

@ -922,7 +922,7 @@ __global__ void Conv3dWgrad(
filter_s = problem_size.S - 1 - filter_s;
}
int d = Z * problem_size.stride_d - problem_size.pad_w + filter_t * problem_size.dilation_d;
int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

View File

@ -352,7 +352,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, xor_add<ComputeType>>(
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
}
@ -367,7 +367,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, xor_add<ComputeType>>(
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
}
};
@ -389,7 +389,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, and_add<ComputeType>>(
ScalarType, ComputeType, and_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
}
@ -404,7 +404,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, and_add<ComputeType>>(
ScalarType, ComputeType, and_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
}
};

View File

@ -42,6 +42,7 @@
#include "cutlass/relatively_equal.h"
#include "cute/tensor.hpp"
#include "cute/pointer.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -59,10 +60,20 @@ struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T
/////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////
//
// Gett Mainloop Parameters
//
///////////////////////////////////////////////////////////
template<
class ElementAccumulator_,
class TensorA_, // (M, K, L)
class TensorB_ // (N, K, L)
, class TensorSfA_ = TensorA_,
class TensorSfB_ = TensorB_
>
struct GettMainloopParams {
using ElementAccumulator = ElementAccumulator_;
@ -79,23 +90,105 @@ struct GettMainloopParams {
ComplexTransform transform_A = ComplexTransform::kNone;
ComplexTransform transform_B = ComplexTransform::kNone;
using TensorSfA = TensorSfA_;
using TensorSfB = TensorSfB_;
using EngineSfA = typename TensorSfA::engine_type;
using LayoutSfA = typename TensorSfA::layout_type;
using EngineSfB = typename TensorSfB::engine_type;
using LayoutSfB = typename TensorSfB::layout_type;
TensorSfA_ SfA{};
TensorSfB_ SfB{};
GettMainloopParams() {}
GettMainloopParams(TensorA tensor_A, TensorB tensor_B)
: A(tensor_A), B(tensor_B) {}
GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
: A(tensor_A), SfA(tensor_SfA),
B(tensor_B), SfB(tensor_SfB) {}
};
////////////////////////////////////////////////////////////////////////
//
// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels
//
////////////////////////////////////////////////////////////////////////
/// Mainloop parameters for the block-scaled reference GETT.
///
/// This is GettMainloopParams with the two scale-factor tensor types supplied
/// explicitly (they have no defaults here). All type aliases are re-exported
/// from the base so that code written against GettMainloopParams can use this
/// type unchanged.
template <
  class ElementAccumulator_,
  class TensorA_,     // (M, K, L)
  class TensorSfA_,   // (M, K, L)
  class TensorB_,     // (N, K, L)
  class TensorSfB_    // (N, K, L)
>
struct GettBlockScalingMainloopParams
    : public GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_> {

  using Base = GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_>;

  using ElementAccumulator = typename Base::ElementAccumulator;

  // Operand tensors and their scale-factor tensors.
  using TensorA   = typename Base::TensorA;
  using TensorB   = typename Base::TensorB;
  using TensorSfA = typename Base::TensorSfA;
  using TensorSfB = typename Base::TensorSfB;

  // Engine / layout types of the operands.
  using EngineA = typename Base::EngineA;
  using LayoutA = typename Base::LayoutA;
  using EngineB = typename Base::EngineB;
  using LayoutB = typename Base::LayoutB;

  // Engine / layout types of the scale factors.
  using EngineSfA = typename Base::EngineSfA;
  using LayoutSfA = typename Base::LayoutSfA;
  using EngineSfB = typename Base::EngineSfB;
  using LayoutSfB = typename Base::LayoutSfB;

  // These two members hide the base-class fields of the same name; each one
  // starts out as a copy of the (already constructed) base value.
  ComplexTransform transform_A = Base::transform_A;
  ComplexTransform transform_B = Base::transform_B;

  GettBlockScalingMainloopParams() {}

  /// Forwards the operands and their scale factors to the four-tensor base constructor.
  GettBlockScalingMainloopParams(
      TensorA tensor_A, TensorSfA tensor_SfA,
      TensorB tensor_B, TensorSfB tensor_SfB)
    : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Scale-factor generation strategy selected through the epilogue parameters.
enum class SfStrategy {
  /// No scale factors are generated (the default in the epilogue parameter templates).
  None = 0,
  /// gett_epilogue runs the 1-D scale-factor generation path for D.
  SfDGen = 1,
};
///////////////////////////////////////////////////////////
//
// Gett Epilogue Parameters
//
///////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementScalingFactor_,
class ElementAccumulator_,
class ElementCompute_,
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = TensorD_, // (M, 1)
class TensorAux_ = TensorD_, // (M, N, L)
class VectorAlpha_ = TensorD_, // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
class TensorAux_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, N, L)
class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
class TensorSFD_ = TensorD_,
class SFD_VectorSize_ = cute::Int<0>,
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
bool PerColumnBias_ = false
,
SfStrategy SfGenStrategy_ = SfStrategy::None
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
@ -108,6 +201,8 @@ struct GettEpilogueParams {
using VectorBias = VectorBias_;
using VectorAlpha = VectorAlpha_;
using VectorBeta = VectorBeta_;
using TensorSFD = TensorSFD_;
using SFD_VectorSize = SFD_VectorSize_;
using ActivationFunctor = ActivationFunctor_;
using BiasBinaryOp = BiasBinaryOp_;
@ -115,7 +210,11 @@ struct GettEpilogueParams {
using LayoutC = typename TensorC::layout_type;
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
using EngineSfD = typename TensorSFD::engine_type;
using LayoutSfD = typename TensorSFD::layout_type;
static constexpr bool PerColumnBias = PerColumnBias_;
static constexpr SfStrategy SfGenStrategy = SfGenStrategy_;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
@ -125,7 +224,8 @@ struct GettEpilogueParams {
TensorAux Aux{};
VectorAlpha Valpha{};
VectorBeta Vbeta{};
ElementCompute st = ElementCompute(1);
TensorSFD SfD{};
ElementCompute st = ElementCompute(1);
ElementAccumulator* abs_max_D = nullptr;
ElementAccumulator* abs_max_Aux = nullptr;
@ -137,8 +237,250 @@ struct GettEpilogueParams {
ElementScalingFactor scale_aux = ElementScalingFactor(1);
bool beta_per_channel_scaling = false;
GettEpilogueParams() {}
GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
: alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {}
GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
: alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {}
GettEpilogueParams(
ElementScalar alpha, ElementScalar beta,
TensorC tensor_C, TensorD tensor_D,
VectorBias bias, TensorAux tensor_aux,
VectorAlpha vector_alpha, VectorBeta vector_beta)
: alpha(alpha), beta(beta),
C(tensor_C), D(tensor_D),
Bias(bias), Aux(tensor_aux),
Valpha(vector_alpha), Vbeta(vector_beta) {}
};
////////////////////////////////////////////////////////////////////////
//
// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels
//
////////////////////////////////////////////////////////////////////////
/// Epilogue parameters for the block-scaled reference GETT.
///
/// Derives from GettEpilogueParams with most of its knobs pinned:
///   * ElementScalingFactor is the same type as ElementScalar;
///   * the bias / aux / per-row alpha / per-row beta operand types are all the
///     type make_tensor yields for a recast ElementCompute_ pointer paired with
///     TensorD_'s layout type;
///   * the activation is Identity, the bias op is plus, per-column bias is off.
/// Only the scale-factor output tensor, its vector size and the generation
/// strategy stay configurable.
template <
  class ElementScalar_,
  class ElementAccumulator_,
  class ElementCompute_,
  class TensorC_,
  class TensorD_,
  class TensorSfD_ = TensorD_,
  class SFD_VectorSize_ = cute::Int<0>,
  SfStrategy SfGenStrategy_ = SfStrategy::None
>
struct GettBlockScalingEpilogueParams
    : public GettEpilogueParams<
        ElementScalar_,                                                                                       // ElementScalar
        ElementScalar_,                                                                                       // ElementScalingFactor
        ElementAccumulator_,                                                                                  // ElementAccumulator
        ElementCompute_,                                                                                      // ElementCompute
        TensorC_,                                                                                             // TensorC      (M, N, L)
        TensorD_,                                                                                             // TensorD      (M, N, L)
        decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // VectorBias   (M, 1)
        decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // TensorAux    (M, N, L)
        decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // VectorAlpha  (M, 1)
        decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // VectorBeta   (M, 1)
        cutlass::epilogue::thread::Identity<ElementCompute_>,                                                 // ActivationFunctor
        TensorSfD_,                                                                                           // TensorSfD
        SFD_VectorSize_,                                                                                      // SFD_VectorSize
        cutlass::plus<ElementCompute_>,                                                                       // BiasBinaryOp
        false,                                                                                                // PerColumnBias
        SfGenStrategy_                                                                                        // SfGenStrategy
      > {

  // Must name exactly the same specialization as the base-class list above.
  using Base = GettEpilogueParams<
    ElementScalar_,                                                                                       // ElementScalar
    ElementScalar_,                                                                                       // ElementScalingFactor
    ElementAccumulator_,                                                                                  // ElementAccumulator
    ElementCompute_,                                                                                      // ElementCompute
    TensorC_,                                                                                             // TensorC      (M, N, L)
    TensorD_,                                                                                             // TensorD      (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // VectorBias   (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // TensorAux    (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // VectorAlpha  (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // VectorBeta   (M, 1)
    cutlass::epilogue::thread::Identity<ElementCompute_>,                                                 // ActivationFunctor
    TensorSfD_,                                                                                           // TensorSfD
    SFD_VectorSize_,                                                                                      // SFD_VectorSize
    cutlass::plus<ElementCompute_>,                                                                       // BiasBinaryOp
    false,                                                                                                // PerColumnBias
    SfGenStrategy_                                                                                        // SfGenStrategy
  >;

  // Element types.
  using ElementScalar        = typename Base::ElementScalar;
  using ElementScalingFactor = typename Base::ElementScalingFactor;
  using ElementAccumulator   = typename Base::ElementAccumulator;
  using ElementCompute       = typename Base::ElementCompute;

  // Operand tensor / vector types.
  using TensorC     = typename Base::TensorC;
  using TensorD     = typename Base::TensorD;
  using TensorAux   = typename Base::TensorAux;
  using TensorSFD   = typename Base::TensorSFD;
  using VectorBias  = typename Base::VectorBias;
  using VectorAlpha = typename Base::VectorAlpha;
  using VectorBeta  = typename Base::VectorBeta;

  // Functors and scale-factor vector size.
  using SFD_VectorSize    = typename Base::SFD_VectorSize;
  using ActivationFunctor = typename Base::ActivationFunctor;
  using BiasBinaryOp      = typename Base::BiasBinaryOp;

  // Engine / layout types of C, D and SfD.
  using EngineC   = typename Base::EngineC;
  using LayoutC   = typename Base::LayoutC;
  using EngineD   = typename Base::EngineD;
  using LayoutD   = typename Base::LayoutD;
  using EngineSfD = typename Base::EngineSfD;
  using LayoutSfD = typename Base::LayoutSfD;

  static constexpr bool PerColumnBias = Base::PerColumnBias;
  static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy;

  GettBlockScalingEpilogueParams() {}

  /// Linear-combination epilogue only: no scale-factor output tensor is supplied.
  GettBlockScalingEpilogueParams(
      ElementScalar alpha, ElementScalar beta,
      TensorC tensor_C, TensorD tensor_D)
    : Base(alpha, beta, tensor_C, tensor_D) {}

  /// Supplies the scale-factor output tensor; `st` is passed to the base as ElementCompute{0}.
  // NOTE(review): the base default for `st` is ElementCompute(1); confirm 0 is intended here.
  GettBlockScalingEpilogueParams(
      ElementScalar alpha, ElementScalar beta,
      TensorC tensor_C, TensorD tensor_D,
      TensorSFD tensor_SfD)
    : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {}

  /// Supplies the scale-factor output tensor together with an explicit `st` value.
  GettBlockScalingEpilogueParams(
      ElementScalar alpha, ElementScalar beta,
      TensorC tensor_C, TensorD tensor_D,
      TensorSFD tensor_SfD, ElementCompute epilogue_st)
    : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {}
};
///////////////////////////////////////////////////////////
//
// Generic Gett 3x Implementation
//
///////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
template <int kVectorSize, class EpilogueParams, class TensorD, class TensorSFD, class ElementCompute, int kBlockM, int kBlockN>
void compute_1d_scaling_factor_and_quantized_output(
EpilogueParams const& epilogue_params,
TensorD &tensor_D,
TensorSFD &tensor_SfD,
int64_t m,
int64_t n,
int64_t l,
ElementCompute (&acc)[kBlockM][kBlockN])
{
using ElementD = typename ElementTraits<typename EpilogueParams::EngineD::value_type>::type;
using ElementSfD = typename ElementTraits<typename EpilogueParams::EngineSfD::value_type>::type;
int const M = cute::size<0>(tensor_D.layout());
int const N = cute::size<1>(tensor_D.layout());
int const L = cute::size<2>(tensor_D.layout());
auto mul = cutlass::multiplies<ElementCompute>{};
auto div = divides<ElementCompute>{};
// Get FP max
ElementCompute fp_max = ElementCompute(std::numeric_limits<ElementD>::max());
float scale_down_factor = div(1.0f, fp_max);
// Get st' = st / FP max
ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor);
absolute_value_op<ElementCompute> abs_op;
maximum_with_nan_propogation<ElementCompute> max_op;
if constexpr (cute::is_constant<1, decltype(cute::stride<0,1>(tensor_SfD))>::value) {
// MN major output
int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize);
// Col major output
for (int n_b = 0; n_b < kBlockN; ++n_b) {
for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
int64_t col = n + n_b;
/// Step1: get max across a vector
ElementCompute accum_max = ElementCompute(0);
for (int v = 0; v < kVectorSize; v++) {
int accum_row = v_b * kVectorSize + v;
int64_t output_row = accum_row + m;
if (output_row < M && col < N) {
accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b]));
}
}
/// Step2: Compute Scale
ElementCompute pvscale = mul(accum_max, st_scaled_down);
ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
// Store the Scaling Factors
int64_t sf_row = m + kVectorSize * v_b;
if (sf_row < M && col < N) {
tensor_SfD(sf_row, col, l) = qpvscale;
}
/// Step3: Compute quantized output values
ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
// Get float reciprocal
ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
// Map INF to fp32::max
acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
// Store the intermediate_accum
for (int v = 0; v < kVectorSize; v++) {
int accum_row = v_b * kVectorSize + v;
int64_t output_row = accum_row + m;
if (output_row < M && col < N) {
acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale);
}
}
}
}
}
else {
int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize);
// row major output
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
int64_t row = m + m_b;
/// Step1: get max across a vector
ElementCompute accum_max = ElementCompute(0);
for (int v = 0; v < kVectorSize; v++) {
int accum_col = v_b * kVectorSize + v;
int64_t output_col = accum_col + n;
if (row < M && output_col < N) {
accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col]));
}
}
/// Step2: Compute Scale
ElementCompute pvscale = mul(accum_max, st_scaled_down);
ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
// Store the Scaling Factors
int64_t sf_col = n + kVectorSize * v_b;
if (row < M && sf_col < N) {
tensor_SfD(row, sf_col, l) = qpvscale;
}
/// Step3: Compute quantized output values
ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
// Get float reciprocal
ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
// Map INF to fp32::max
acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
// Store the intermediate_accum
for (int v = 0; v < kVectorSize; v++) {
int accum_col = v_b * kVectorSize + v;
int64_t output_col = accum_col + n;
if (row < M && output_col < N) {
acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale);
}
}
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - General Tensor-Tensor contraction reference kernel
@ -188,6 +530,11 @@ void gett_mainloop(
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
using ElementSFA = typename ElementTraits<typename MainloopParams::EngineSfA::value_type>::type;
using ElementSFB = typename ElementTraits<typename MainloopParams::EngineSfB::value_type>::type;
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
RingOp fma_op;
@ -207,6 +554,14 @@ void gett_mainloop(
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
if constexpr (not cute::is_same_v<ElementSFA, ElementA>){
// Load SFA
auto sfa = static_cast<ElementAccumulator>(mainloop_params.SfA(m + m_b, k, l));
a_frag[m_b] *= sfa;
}
if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
a_frag[m_b] = conj(a_frag[m_b]);
}
@ -222,6 +577,14 @@ void gett_mainloop(
// Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
if constexpr (not cute::is_same_v<ElementSFB, ElementB>){
// Load SFB
auto sfb = static_cast<ElementAccumulator>(mainloop_params.SfB(n + n_b, k, l));
b_frag[n_b] *= sfb;
}
if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
b_frag[n_b] = conj(b_frag[n_b]);
}
@ -259,6 +622,7 @@ void gett_epilogue(
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::TensorC::value_type;
using ElementD = typename EpilogueParams::TensorD::value_type;
using ElementSfD = typename EpilogueParams::TensorSFD::value_type;
using ElementAux = typename EpilogueParams::TensorAux::value_type;
using ElementBias = typename EpilogueParams::VectorBias::value_type;
using ElementScalar = typename EpilogueParams::ElementScalar;
@ -267,6 +631,8 @@ void gett_epilogue(
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy;
constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
@ -412,6 +778,17 @@ void gett_epilogue(
}
}
} // m_b
if constexpr (
SfGenStrategy == SfStrategy::SfDGen
) {
// 1d scale factor generation
constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{};
if (epilogue_params.SfD.data() != nullptr) {
compute_1d_scaling_factor_and_quantized_output<kVectorSize>(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum);
}
}
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {