Improvements for: Groupwise scaling along M for FP8 gemm (#2095)
* fix blockwise fp8 kernels
* wip, < 128 not working
* fix < 128
* reduce diff
* review comments
* support partial n blocks
* fix build errors

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@@ -557,13 +557,13 @@ bool verify(const Options<RasterOrderOptions> &options) {
   auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
     cute::make_layout(
       cute::make_shape(blockscale_m, blockscale_k, options.l),
-      cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k)
+      cute::make_stride(1, blockscale_m, blockscale_m * blockscale_k)
     )
   );
   auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
     cute::make_layout(
       cute::make_shape(blockscale_n, blockscale_k, options.l),
-      cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
+      cute::make_stride(1, blockscale_n, blockscale_n * blockscale_k)
     )
   );

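The stride swap above turns the host-side scale layouts from K-major into M-major (N-major for B), so consecutive scales along M for a single k-block become contiguous in memory, matching what the updated kernel reads. A minimal CuTe sketch of the index arithmetic (sizes are illustrative, not from the commit):

  #include <cute/tensor.hpp>
  #include <cstdio>

  int main() {
    using namespace cute;
    int blockscale_m = 4, blockscale_k = 3;
    // Old: K-major, the k index moves fastest in memory.
    auto k_major = make_layout(make_shape(blockscale_m, blockscale_k, 1),
                               make_stride(blockscale_k, 1, blockscale_m * blockscale_k));
    // New: M-major, the m index moves fastest in memory.
    auto m_major = make_layout(make_shape(blockscale_m, blockscale_k, 1),
                               make_stride(1, blockscale_m, blockscale_m * blockscale_k));
    // Same logical coordinate, different linear offsets:
    printf("k_major(1,2,0) = %d\n", int(k_major(1, 2, 0))); // 1*3 + 2*1 = 5
    printf("m_major(1,2,0) = %d\n", int(m_major(1, 2, 0))); // 1*1 + 2*4 = 9
    return 0;
  }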
@@ -396,14 +396,17 @@ template <typename GroupScaleConfig>
 void initialize(const Options<RasterOrderOptions> &options) {

   using TileShape = typename GroupScaleConfig::TileShape;
   const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
   const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
+  const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM;
+  const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN;
+
+  assert(options.m % ScaleGranularityM == 0);
+  assert(options.n % ScaleGranularityN == 0);

   // Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
   auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
   auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
-  auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
-  auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of B to prevent illegal memory access.
+  auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM;
+  auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN;
   auto blockscale_k = cute::get<2>(blockscale_shape);

   stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
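Under the new scheme the group-scale extents come from exact division of the problem shape instead of being padded up to whole CTA tiles, which is why the two divisibility asserts are added. Worked numbers (assumed for illustration, not from the commit):

  // TileShape = (128, 128, 128), ScaleGranularityM = 64  =>  ScaleMsPerTile = 2
  // options.m = 320, options.k = 256:
  //   blockscale_m = ceil(320 / 128) = 3,  blockscale_k = 256 / 128 = 2
  //   old: groupscale_m = blockscale_m * ScaleMsPerTile = 3 * 2 = 6    (padded to whole tiles)
  //   new: groupscale_m = options.m / ScaleGranularityM = 320 / 64 = 5 (exact; 320 % 64 == 0)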
@@ -575,6 +578,8 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
   //
   // Compute reference output
   //
+  const int ScaleGranularityM = get<0>(TileShape_{}) / ScaleMsPerTile;
+  const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile;

   // Group scaling tensor shapes based on `ScaleGranularityM`, the CTA tile (TileShape), and the GEMM problem shape
   auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
@@ -582,6 +587,8 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
   auto blockscale_m = cute::get<0>(blockscale_shape);
   auto blockscale_n = cute::get<1>(blockscale_shape);
   auto blockscale_k = cute::get<2>(blockscale_shape);
+  auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM;
+  auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN;

   // Create instantiation for device reference gemm kernel
   auto A = cute::make_tensor(tensor_A.host_data(),
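Here verify() recovers the granularities from the CTA tile and the per-tile scale counts, keeping the reference path in sync with the kernel configuration: with an illustrative TileShape_ of (128, 128, 128), ScaleMsPerTile = 128 yields ScaleGranularityM = 128 / 128 = 1 (one scale per output row), while ScaleMsPerTile = 1 yields ScaleGranularityM = 128 (one scale per 128-row block).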
@@ -617,14 +624,14 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile

   auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
     cute::make_layout(
-      cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
-      cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
+      cute::make_shape(groupscale_m, blockscale_k, options.l),
+      cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k)
     )
   );
   auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
     cute::make_layout(
-      cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l),
-      cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile)
+      cute::make_shape(groupscale_n, blockscale_k, options.l),
+      cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k)
     )
   );
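The reference-side scale tensors collapse from 4-D block-plus-offset layouts to flat 3-D, M-major ones. The new strides are just packed column-major; assuming that reading, the explicit stride is redundant and the two CuTe expressions below describe the same layout (sketch, not part of the commit):

  auto explicit_layout = cute::make_layout(
      cute::make_shape(groupscale_m, blockscale_k, options.l),
      cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k));
  auto default_layout = cute::make_layout(
      cute::make_shape(groupscale_m, blockscale_k, options.l)); // LayoutLeft (column-major) by default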
@@ -708,6 +715,31 @@ int run(Options<RasterOrderOptions> &options)
   const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
   const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;

+  bool skip = false;
+
+  if (options.m % ScaleGranularityM != 0) {
+    std::cout << "Skipping (m size: " << options.m << " not divisible by ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
+    skip = true;
+  }
+
+  if (options.n % ScaleGranularityN != 0) {
+    std::cout << "Skipping (n size: " << options.n << " not divisible by ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
+    skip = true;
+  }
+
+  if (options.k % size<2>(TileShape{}) != 0) {
+    std::cout << "Skipping (k size: " << options.k << " not divisible by TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
+    skip = true;
+  }
+
+  if (!skip) std::cout << "Running: " << std::endl;
+  std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
+  std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
+  std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
+  std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
+
+  if (skip) return -1;
+
   initialize<GroupScaleConfig>(options);

   // Instantiate CUTLASS kernel depending on templates
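The guard prints the configuration before bailing out, so a sweep over granularities reports what was skipped and why. With illustrative sizes: TileShape = (128, 128, 128) and ScaleGranularityM = 64 skips m = 100 (100 % 64 != 0) but runs m = 192, which uses 192 / 64 = 3 scale values along M per k-block. (Note the n-size check originally printed options.m and ScaleGranularityM; the copy-paste slip is corrected above.)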
@@ -768,10 +800,6 @@ int run(Options<RasterOrderOptions> &options)
     raster = "Along M";
   }

-  std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
-  std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
-  std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
-  std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
   std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
   std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
   std::cout << " GFLOPS: " << result.gflops << std::endl;
@@ -217,15 +217,19 @@ void gett_mainloop(
     }
   }

-  int64_t block_m = m / kBlockM;
-  int64_t block_n = n / kBlockN;
-  cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
-  cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l);
-  const int M = cute::size<0>(mainloop_params.A.layout());
-  const int N = cute::size<0>(mainloop_params.B.layout());
-  const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA);
-  const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB);
-  assert(ScaleGranularityM && M % ScaleGranularityM == 0
-         && "ScaleGranularityM must divide M");
-  assert(ScaleGranularityN && N % ScaleGranularityN == 0
-         && "ScaleGranularityN must divide N");
+  const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
+  const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape());
+  assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
+  assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape()));
+  cute::Tensor blockscale_A = domain_offset(
+      make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l));
+  cute::Tensor blockscale_B = domain_offset(
+      make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l));

   // Compute on this k-block
   for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
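Instead of slicing one scale block out of ScaleA/ScaleB, the reference mainloop now shifts the whole scale tensor's coordinate domain to the tile's starting row, so the later scale_a[m_b / ScaleGranularityM] lookups are tile-relative. A small sketch of cute::domain_offset semantics (buffer and sizes assumed, inside a function with <cute/tensor.hpp> and <cassert> included):

  float buf[16];
  auto t = cute::make_tensor(&buf[0], cute::make_layout(cute::make_shape(8, 2)));
  // Shift the coordinate domain by (4, 0): t_off(i, j) aliases t(i + 4, j).
  auto t_off = cute::domain_offset(cute::make_coord(4, cute::_0{}), t);
  assert(&t_off(0, 0) == &t(4, 0));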
@@ -257,9 +261,12 @@ void gett_mainloop(
     }
   }

+  int m_size = std::min(static_cast<int64_t>(kBlockM), cute::size<0>(mainloop_params.A.layout()) - m);
+  int n_size = std::min(static_cast<int64_t>(kBlockN), cute::size<0>(mainloop_params.B.layout()) - n);
+
   // do compute
-  for (int m_b = 0; m_b < kBlockM; ++m_b) {
-    for (int n_b = 0; n_b < kBlockN; ++n_b) {
+  for (int m_b = 0; m_b < m_size; ++m_b) {
+    for (int n_b = 0; n_b < n_size; ++n_b) {
       acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
     }
   }
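m_size and n_size clamp the per-tile loops at the problem boundary; this is the host-reference half of the "support partial n blocks" item in the commit message. With assumed sizes kBlockM = 128 and M = 300, the last tile starts at m = 256 and gets m_size = min(128, 300 - 256) = 44, so the loops never read past row 299.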
@@ -269,9 +276,9 @@ void gett_mainloop(
   // (b) Zero-out partial temporary (acc_temp),
   // (c) Update permanent (accu)
   if ((k+1) % kBlockK == 0) {
-    for (int m_b = 0; m_b < kBlockM; ++m_b) {
+    for (int m_b = 0; m_b < m_size; ++m_b) {
       auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
-      for (int n_b = 0; n_b < kBlockN; ++n_b) {
+      for (int n_b = 0; n_b < n_size; ++n_b) {
         auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
         ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
         acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
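Written out, the update this loop implements (with gM = ScaleGranularityM, gN = ScaleGranularityN) is:

  acc[m_b][n_b] += scale_a[m_b / gM] * scale_b[n_b / gN] * ( sum over k in the k-block of A(m + m_b, k) * B(n + n_b, k) )

i.e. products accumulate unscaled into acc_temp within a k-block and are promoted with the groupwise scales once per k-block.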