// Source file: cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp
// Snapshot: 2023-04-29 09:34:27 -04:00 (183 lines, 8.0 KiB, C++)

/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/layout.hpp"
namespace cutlass::gemm::kernel::detail {
///////////////////////////////////////////////////////////////////////////////
// Persistent Thread Block (TB) scheduler
class PersistentTileSchedulerSm90 {
//
// Data members
//
private:
uint64_t current_work_linear_idx_{static_cast<uint64_t>((int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y))};
uint64_t grid_blocks_total_{static_cast<uint64_t>(int(gridDim.x) * int(gridDim.y))};
struct WorkTileInfo {
int32_t M_idx = 0;
int32_t N_idx = 0;
int32_t L_idx = 0;
uint32_t is_valid_tile = false;
};
//
// Methods
//
public:
struct Params {
FastDivmodU64 divmod_batch_{};
FastDivmodU64 divmod_grid_y_{};
FastDivmodU64 divmod_blk_m_{};
uint64_t blocks_per_problem_ = 0;
};
template <class ProblemShapeMNKL, class TileShape, class ClusterShape>
static Params
to_underlying_arguments(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) {
// We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic
static_assert(is_static<TileShape>::value);
static_assert(is_static<ClusterShape>::value);
// Round up to nearest multiple of cluster dim along each mode
auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_blk_shape_mnl(
problem_shape_mnkl, tile_shape, cluster_shape);
return {
FastDivmodU64(problem_blocks_m * problem_blocks_n),
FastDivmodU64(size<1>(cluster_shape)),
FastDivmodU64(problem_blocks_m),
problem_blocks_m * problem_blocks_n * problem_blocks_l
};
}
PersistentTileSchedulerSm90() = default;
CUTLASS_DEVICE
WorkTileInfo
get_current_work(Params const& scheduler_params) const {
// Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices
uint64_t work_idx_l, remainder;
scheduler_params.divmod_batch_(work_idx_l, remainder, current_work_linear_idx_);
uint64_t blk_per_grid_dim, dontcare;
scheduler_params.divmod_grid_y_(blk_per_grid_dim, dontcare, remainder);
uint64_t block_idx_m, block_idx_n;
scheduler_params.divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim);
int32_t work_idx_m = static_cast<int32_t>(block_idx_m);
int32_t work_idx_n = static_cast<int32_t>((block_idx_n * gridDim.y) + blockIdx.y);
return {work_idx_m, work_idx_n, static_cast<int32_t>(work_idx_l), current_work_linear_idx_ < scheduler_params.blocks_per_problem_};
}
CUTLASS_DEVICE
void
advance_to_next_work(uint32_t advance_count = 1) {
current_work_linear_idx_ += grid_blocks_total_ * advance_count;
}
// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE constexpr static
dim3
get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) {
// Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles
auto blk_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(blk_shape)));
auto blk_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(blk_shape)));
// Round up to nearest multiple of cluster dim along each mode
int problem_blocks_m = round_up(blk_m, cute::size<0>(cluster_shape));
int problem_blocks_n = round_up(blk_n, cute::size<1>(cluster_shape));
// Cluster tile does not span the batch mode, so no extra rounding up required for it
int problem_blocks_l = int(cute::size<3>(problem_shape_mnkl));
return {uint32_t(problem_blocks_m), uint32_t(problem_blocks_n), uint32_t(problem_blocks_l)};
}
// Given the inputs, computes the physical grid we should launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE constexpr static
dim3
get_grid_shape(ProblemShapeMNKL problem_shape_mnk, BlockShape blk_shape, ClusterShape cluster_shape, KernelHardwareInfo hw_info) {
int const sm_count = hw_info.sm_count;
CUTLASS_TRACE_HOST("get_grid_shape(): Persistent schedule grid plan using SM count = " << sm_count);
// Compute the total number of output tiles our problem has
auto problem_shape_MNKL = append<4>(problem_shape_mnk, Int<1>{});
auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] =
get_tiled_blk_shape_mnl(problem_shape_MNKL, blk_shape, cluster_shape);
int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks_l;
dim3 launch_grid(1, cute::size<1>(cluster_shape), 1);
// The else path is generic, however, we can avoid some divs if we know Cluster size is 1
if constexpr (size(cluster_shape) == 1) {
launch_grid.x = std::min(sm_count, problem_blocks_total);
}
else {
/*
* Optimal grid size calculation is based on
* GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU
* Hence, maximum SMs per GPC = 18
*/
constexpr int max_sm_per_gpc = 18;
// Provided SM count could possibly be less than the assumed maximum SMs per GPC
int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc;
int const max_blk_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(cluster_shape));
int blk_per_device = min_num_gpc * max_blk_occupancy_per_gpc;
// The calculation below allows for larger grid size launch for different GPUs.
int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc;
int const max_blk_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % size(cluster_shape));
blk_per_device += max_blk_occupancy_per_residual_gpc;
blk_per_device = sm_count < blk_per_device ? sm_count : blk_per_device;
launch_grid.x = std::min(
blk_per_device / size<1>(cluster_shape),
problem_blocks_total / size<1>(cluster_shape));
}
return launch_grid;
}
};
} // namespace cutlass::gemm::kernel::detail