added support for b2b bmm (#849)

* added support for b2b bmm

* fixed arguments and params structures

* added batch_count argument

* removed SplitKSerial and added a new test case with b2b bmm

* fixed support for kBatched mode and added a new test case with batch stride

* added batch support for bias and scale

* make test

* small changes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Author: Aleksandr Pivovar
Date: 2023-04-15 05:20:02 +02:00
Committed by: GitHub
Parent: d572cc1aab
Commit: 4a68cf748e
10 changed files with 765 additions and 409 deletions


@@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
MatrixCoord output_coord(
@@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
@@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
}
}
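For orientation, the epilogue these reference kernels compute is out = convert_op(scale * in + bias), with scale and bias broadcast along each output column. A minimal host-side sketch of the same arithmetic (standalone C++; the values are hypothetical, not from this commit):

#include <cstdio>

// out[r][c] = scale[c] * in[r][c] + bias[c], scale/bias broadcast per column.
int main() {
  int const M = 2, N = 3;
  float in[M][N] = {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}};
  float scale[N] = {2.f, 0.5f, 1.f};   // per-column scale
  float bias[N]  = {1.f, 0.f, -1.f};   // per-column bias
  float out[M][N];

  for (int r = 0; r < M; ++r)
    for (int c = 0; c < N; ++c)
      out[r][c] = scale[c] * in[r][c] + bias[c];

  for (int r = 0; r < M; ++r)
    printf("%g %g %g\n", out[r][0], out[r][1], out[r][2]);
  return 0;
}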
/// Apply scale and bias on a tensor, batched
template <
  typename TensorRefIn,                ///< Input TensorRef Type
  typename TensorRefOut,               ///< Output TensorRef Type
  typename ScalarType,                 ///< alpha Type
  typename TensorRefScalar,            ///< Scale/Bias TensorRef Type
  typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
  int kMblock = 4,
  int kNblock = 4
>
__global__ void TensorScaleBiasGemmBatched(
  gemm::GemmCoord problem_size,
  TensorRefIn tensor_in,               ///< input tensor
  TensorRefOut tensor_out,             ///< output tensor
  ScalarType alpha,                    ///< alpha
  TensorRefScalar tensor_scale,        ///< scale tensor
  TensorRefScalar tensor_bias,         ///< bias tensor
  int batch_count = 1,
  int64_t batch_stride_tensor_in = 0,
  int64_t batch_stride_tensor_out = 0,
  int64_t batch_stride_tensor_scale = 0,
  int64_t batch_stride_tensor_bias = 0
) {

  ConvertOp convert_op;

  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;

  // Each z-block starts at batch blockIdx.z and strides by gridDim.z.
  int batch_idx = blockIdx.z;

  tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
  tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
  tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
  tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);

  for (; batch_idx < batch_count; batch_idx += gridDim.z) {

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kNblock; j++) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kMblock; i++) {
        int row = row_block + i;
        int col = col_block + j;
        MatrixCoord coord = MatrixCoord(row, col);

        if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
          ScalarType scale = alpha;
          if(tensor_scale.good())
            scale = tensor_scale.at({0, coord.column()});

          ScalarType bias = ScalarType(0);
          if(tensor_bias.good())
            bias = tensor_bias.at({0, coord.column()});

          tensor_out.at(coord) = convert_op(
            scale * ScalarType(tensor_in.at(coord)) + bias);
        }
      }
    }

    // Advance all tensor pointers to this z-block's next batch.
    tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
    tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
    tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
    tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
  }
}
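The batched kernel above parallelizes over batches with a grid-stride loop in z: each z-block starts at batch blockIdx.z and advances its TensorRef pointers by batch_stride * gridDim.z per iteration, so batch_count may exceed gridDim.z. A self-contained toy CUDA kernel illustrating the same pattern (not part of this commit; all names are hypothetical):

#include <cstdio>
#include <cuda_runtime.h>

// Toy grid-stride-z batching: scale each batch's vector in place,
// mirroring the pointer-advance pattern in TensorScaleBiasGemmBatched.
__global__ void scale_batched(float *data, int n, float alpha,
                              int batch_count, long long batch_stride) {
  int batch_idx = blockIdx.z;
  float *ptr = data + batch_idx * batch_stride;   // initial per-batch offset

  for (; batch_idx < batch_count; batch_idx += gridDim.z) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
      ptr[i] *= alpha;
    }
    ptr += batch_stride * gridDim.z;              // hop to this z-block's next batch
  }
}

int main() {
  int const n = 8, batches = 5;
  long long const stride = n;
  float h[n * batches];
  for (int i = 0; i < n * batches; ++i) h[i] = 1.0f;

  float *d;
  cudaMalloc(&d, sizeof(h));
  cudaMemcpy(d, h, sizeof(h), cudaMemcpyHostToDevice);

  // grid.z (here 2) is smaller than batch_count, so each z-block
  // loops over several batches, just like the reference kernel.
  dim3 block(32), grid(1, 1, 2);
  scale_batched<<<grid, block>>>(d, n, 2.0f, batches, stride);

  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("h[0]=%g h[%d]=%g\n", h[0], n * (batches - 1), h[n * (batches - 1)]);
  cudaFree(d);
  return 0;
}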
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
@@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
@@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
int64_t npq = npq_start + m;
thread_n[m] = int(npq / PQ);
int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
@@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, thread_k});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, thread_k});
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ScalarType(
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
) + bias);
}
}
}
}
}
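TensorScaleBiasConv2d linearizes the output coordinate as npq = n * (P*Q) + p * Q + q and recovers (n, p, q) with the division/modulo chain shown in the fragment above. A quick standalone round-trip check of that decomposition (hypothetical extents, not from this commit):

#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical output extent N x P x Q = 2 x 3 x 4.
  int const N = 2, P = 3, Q = 4;
  int64_t const PQ = int64_t(P) * Q;

  for (int64_t npq = 0; npq < int64_t(N) * PQ; ++npq) {
    // Same division/modulo chain as TensorScaleBiasConv2d.
    int n = int(npq / PQ);
    int64_t residual = npq % PQ;
    int p = int(residual / Q);
    int q = int(residual % Q);

    // Recomposing must round-trip to the linear index.
    assert(npq == int64_t(n) * PQ + int64_t(p) * Q + q);
  }
  return 0;
}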
@@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
);
}
/// Apply scale and bias on a tensor, batched
template <
  typename ElementIn,                  ///< Input Type
  typename ElementOut,                 ///< Output Type
  typename Layout,                     ///< Layout of input/output tensor
  typename ScalarType,                 ///< alpha Type
  typename LayoutScaleBias,            ///< Layout of scale and bias
  typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemmBatched(
  gemm::GemmCoord problem_size,
  TensorRef<ElementIn, Layout> tensor_in,                ///< input tensor
  TensorRef<ElementOut, Layout> tensor_out,              ///< output tensor
  ScalarType alpha,                                      ///< alpha
  TensorRef<ScalarType, LayoutScaleBias> tensor_scale,   ///< scale tensor
  TensorRef<ScalarType, LayoutScaleBias> tensor_bias,    ///< bias tensor
  int batch_count = 1,
  int64_t batch_stride_tensor_in = 0,
  int64_t batch_stride_tensor_out = 0,
  int64_t batch_stride_tensor_scale = 0,
  int64_t batch_stride_tensor_bias = 0
) {

  int const kMblock = 4;
  int const kNblock = 4;

  dim3 block(16, 8);
  dim3 grid(
    (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
    (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
    // Cap grid.z below the CUDA limit of 65535; the kernel's
    // grid-stride loop in z covers any remaining batches.
    batch_count % std::numeric_limits<uint16_t>::max()
  );

  kernel::TensorScaleBiasGemmBatched<
    TensorRef<ElementIn, Layout>,
    TensorRef<ElementOut, Layout>,
    ScalarType,
    TensorRef<ScalarType, LayoutScaleBias>,
    ConvertOp,
    kMblock,
    kNblock
  ><<< grid, block >>>(
    problem_size,
    tensor_in,
    tensor_out,
    alpha,
    tensor_scale,
    tensor_bias,
    batch_count,
    batch_stride_tensor_in,
    batch_stride_tensor_out,
    batch_stride_tensor_scale,
    batch_stride_tensor_bias
  );
}
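A hedged usage sketch of the new host-side helper: this assumes the helper lives in cutlass::reference::device as the surrounding reference code suggests; the include path, buffer names, and extents are hypothetical, so treat it as a sketch rather than the commit's own test code.

// Sketch only: assumes CUTLASS headers on the include path and that this
// helper is in cutlass::reference::device; the header path is hypothetical.
#include "cutlass/util/reference/device/tensor_scale_bias.h"

using cutlass::layout::RowMajor;

// in/out are m x n per batch; scale/bias are 1 x n per batch,
// all laid out contiguously batch after batch in device memory.
void scale_bias_batched(float *in, float *out, float *scale, float *bias,
                        int m, int n, int k, int batch_count) {
  cutlass::gemm::GemmCoord problem(m, n, k);

  cutlass::TensorRef<float, RowMajor> ref_in(in, RowMajor(n));
  cutlass::TensorRef<float, RowMajor> ref_out(out, RowMajor(n));
  cutlass::TensorRef<float, RowMajor> ref_scale(scale, RowMajor(n));
  cutlass::TensorRef<float, RowMajor> ref_bias(bias, RowMajor(n));

  cutlass::reference::device::TensorScaleBiasGemmBatched(
    problem, ref_in, ref_out,
    /* alpha = */ 1.0f,
    ref_scale, ref_bias,
    batch_count,
    /* batch_stride_tensor_in    = */ int64_t(m) * n,
    /* batch_stride_tensor_out   = */ int64_t(m) * n,
    /* batch_stride_tensor_scale = */ int64_t(n),
    /* batch_stride_tensor_bias  = */ int64_t(n));
}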
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type