[Bugfix][Rocm] fix qr error when different inp shape (#25892)

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
haoyangli-amd
2025-10-14 01:04:21 +08:00
committed by GitHub
parent a1b2d658ee
commit 134f70b3ed
3 changed files with 96 additions and 11 deletions

View File

@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) {
uint32_t data_offset, uint32_t flag_color,
int64_t data_size_per_phase) {
int block = blockIdx.x;
int grid = gridDim.x;
while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color);
flag_color, data_size_per_phase);
block += grid;
flag_color++;
}
@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
}
enum QuickReduceQuantLevel {

View File

@ -553,13 +553,12 @@ struct AllReduceTwoshot {
int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) {
uint32_t flag_color, int64_t data_size_per_phase) {
// Topology
int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank);
int block_id = blockIdx.x;
int grid_size = gridDim.x;
// --------------------------------------------------------
// Read input into registers
int32x4_t tA[kAtoms];
@ -588,12 +587,10 @@ struct AllReduceTwoshot {
// rank responsible for this segment.
uint32_t comm_data0_offset =
data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset =
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset =
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;
for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =