[Bugfix][Rocm] fix qr error when different inp shape (#25892)
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T>
|
||||
__global__ __quickreduce_launch_bounds_two_shot__ static void
|
||||
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
|
||||
int rank, uint8_t** dbuffer_list,
|
||||
uint32_t data_offset, uint32_t flag_color) {
|
||||
uint32_t data_offset, uint32_t flag_color,
|
||||
int64_t data_size_per_phase) {
|
||||
int block = blockIdx.x;
|
||||
int grid = gridDim.x;
|
||||
|
||||
while (block < num_blocks) {
|
||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
|
||||
flag_color);
|
||||
flag_color, data_size_per_phase);
|
||||
block += grid;
|
||||
flag_color++;
|
||||
}
|
||||
@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
flag_color, this->kMaxProblemSize); \
|
||||
} else if (world_size == 4) { \
|
||||
using LineCodec = __codec<T, 4>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
flag_color, this->kMaxProblemSize); \
|
||||
} else if (world_size == 8) { \
|
||||
using LineCodec = __codec<T, 8>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
flag_color, this->kMaxProblemSize); \
|
||||
}
|
||||
|
||||
enum QuickReduceQuantLevel {
|
||||
|
||||
@ -553,13 +553,12 @@ struct AllReduceTwoshot {
|
||||
int const rank, // rank index
|
||||
uint8_t** __restrict__ buffer_list, // communication buffers
|
||||
uint32_t const data_offset, // offset to start of the data buffer
|
||||
uint32_t flag_color) {
|
||||
uint32_t flag_color, int64_t data_size_per_phase) {
|
||||
// Topology
|
||||
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
||||
uint8_t* rank_buffer = buffer_list[rank];
|
||||
Codec codec(thread, rank);
|
||||
int block_id = blockIdx.x;
|
||||
int grid_size = gridDim.x;
|
||||
// --------------------------------------------------------
|
||||
// Read input into registers
|
||||
int32x4_t tA[kAtoms];
|
||||
@ -588,12 +587,10 @@ struct AllReduceTwoshot {
|
||||
// rank responsible for this segment.
|
||||
uint32_t comm_data0_offset =
|
||||
data_offset + block_id * Codec::kTransmittedTileSize;
|
||||
uint32_t comm_data1_offset =
|
||||
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
|
||||
uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;
|
||||
|
||||
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
||||
uint32_t comm_flags1_offset =
|
||||
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
|
||||
uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
|
||||
Reference in New Issue
Block a user