Files
cutlass/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp
2025-06-06 02:39:20 -04:00

205 lines
6.5 KiB
C++

/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct IndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
IndividualTileScheduler(Params const&) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size));
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
}
CUTLASS_DEVICE
IndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_h;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size);
return Params {
num_blocks,
{ num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
}
CUTLASS_DEVICE
PersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
template<typename Base>
struct TileSchedulerBwdAdapter {
using Params = typename Base::Params;
Base base_;
CUTLASS_DEVICE
TileSchedulerBwdAdapter(Params const& params) : base_(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape));
}
static dim3 get_grid_shape(Params const& params) {
return Base::get_grid_shape(params);
}
CUTLASS_DEVICE
bool is_valid() {
return base_.is_valid();
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return select<1,0,2>(base_.get_block_coord());
}
CUTLASS_DEVICE
TileSchedulerBwdAdapter& operator++() {
++base_;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel