added support of b2b bmm (#849)
* added support of b2b bmm
* fixed arguments and params structures
* added batch_count argument
* removed SplitKSerial and added new test case with b2b bmm
* fixed support of kBatched and added new test case with batch stride
* added batch support for bias and scale
* make test
* small changes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent d572cc1aab
commit 4a68cf748e
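The batching added throughout this diff follows one addressing pattern: every tensor argument gains an int64_t batch stride measured in elements, and the view for batch b is obtained by advancing the TensorRef pointer by b * batch_stride. As a minimal standalone sketch of that pattern, here is a hypothetical helper (batch_view is not part of this commit; it assumes only TensorRef::add_pointer_offset, the same call the kernels below use):

// Hypothetical helper, not code from this commit: materialize a view of
// batch `b` from a base TensorRef plus its batch stride. This mirrors the
// add_pointer_offset() calls in TensorScaleBiasGemmBatched below.
template <typename TensorRef>
TensorRef batch_view(TensorRef ref, int b, int64_t batch_stride) {
  ref.add_pointer_offset(int64_t(b) * batch_stride);  // offset in elements, not bytes
  return ref;
}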
@@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
  TensorRefScalar tensor_scale,  ///< scale tensor
  TensorRefScalar tensor_bias    ///< bias tensor
) {

  ConvertOp convert_op;

  MatrixCoord output_coord(
@@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(

          ScalarType bias = ScalarType(0);

          if(tensor_bias.good())
            bias = tensor_bias.at({0, coord.column()});

          tensor_out.at(coord) = convert_op(
@@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
  }
}

template <
  typename TensorRefIn,      ///< Input TensorRef Type
  typename TensorRefOut,     ///< Output TensorRef Type
  typename ScalarType,       ///< alpha Type
  typename TensorRefScalar,  ///< Scale/Bias TensorRef Type
  typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
  int kMblock = 4,
  int kNblock = 4
>
__global__ void TensorScaleBiasGemmBatched(
  gemm::GemmCoord problem_size,
  TensorRefIn tensor_in,         ///< input tensor
  TensorRefOut tensor_out,       ///< output tensor
  ScalarType alpha,              ///< alpha
  TensorRefScalar tensor_scale,  ///< scale tensor
  TensorRefScalar tensor_bias,   ///< bias tensor
  int batch_count = 1,
  int64_t batch_stride_tensor_in = 0,
  int64_t batch_stride_tensor_out = 0,
  int64_t batch_stride_tensor_scale = 0,
  int64_t batch_stride_tensor_bias = 0
) {

  ConvertOp convert_op;

  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
  int batch_idx = blockIdx.z;

  tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
  tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
  tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
  tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);

  for (; batch_idx < batch_count; batch_idx += gridDim.z) {
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kNblock; j++) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kMblock; i++) {
        int row = row_block + i;
        int col = col_block + j;
        MatrixCoord coord = MatrixCoord(row, col);

        if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {

          ScalarType scale = alpha;
          if(tensor_scale.good())
            scale = tensor_scale.at({0, coord.column()});

          ScalarType bias = ScalarType(0);
          if(tensor_bias.good())
            bias = tensor_bias.at({0, coord.column()});

          tensor_out.at(coord) = convert_op(
            scale * ScalarType(tensor_in.at(coord)) + bias);
        }
      }
    }

    tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
    tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
    tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
    tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
  }
}

template <
  typename TensorRefIn,      ///< Input TensorRef Type
  typename TensorRefOut,     ///< Output TensorRef Type
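The new kernel handles arbitrarily large batch counts with a grid-stride loop over blockIdx.z: each z-slice of the grid starts at its own batch index and then strides by gridDim.z, which is why the host launch further down can clamp grid.z below the 65535 hardware limit. An isolated sketch of the same pattern (a hypothetical kernel for illustration, not code from this diff):

// Hypothetical illustration of the z-dimension grid-stride batching used by
// TensorScaleBiasGemmBatched above: grid.z may be smaller than batch_count,
// so each z-block iterates, stepping gridDim.z batches at a time.
__global__ void scale_batches(float const *in, float *out,
                              int n, int batch_count, int64_t batch_stride) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (int b = blockIdx.z; b < batch_count; b += gridDim.z) {
    if (idx < n) {
      // Offset into batch b; batch_stride is measured in elements.
      out[b * batch_stride + idx] = 2.0f * in[b * batch_stride + idx];
    }
  }
}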
@@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
  TensorRefScalar tensor_scale,  ///< scale tensor
  TensorRefScalar tensor_bias    ///< bias tensor
) {

  ConvertOp convert_op;

  int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
@@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
      int64_t npq = npq_start + m;

      thread_n[m] = int(npq / PQ);

      int64_t residual = npq % PQ;
      thread_p[m] = int(residual / problem_size.Q);
      thread_q[m] = int(residual % problem_size.Q);
@@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
          ScalarType scale = alpha;
          if(tensor_scale.good())
            scale = tensor_scale.at({0, thread_k});

          ScalarType bias = ScalarType(0);
          if(tensor_bias.good())
            bias = tensor_bias.at({0, thread_k});

          tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
            scale * ScalarType(
              tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
            ) + bias);
        }
      }
    }
  }
}

@@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
  );
}

/// Apply scale and bias on a tensor
template <
  typename ElementIn,        ///< Input Type
  typename ElementOut,       ///< Output Type
  typename Layout,           ///< Layout of input/output tensor
  typename ScalarType,       ///< alpha Type
  typename LayoutScaleBias,  ///< Layout of scale and bias
  typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemmBatched(
  gemm::GemmCoord problem_size,
  TensorRef<ElementIn, Layout> tensor_in,               ///< input tensor
  TensorRef<ElementOut, Layout> tensor_out,             ///< output tensor
  ScalarType alpha,                                     ///< alpha
  TensorRef<ScalarType, LayoutScaleBias> tensor_scale,  ///< scale tensor
  TensorRef<ScalarType, LayoutScaleBias> tensor_bias,   ///< bias tensor
  int batch_count = 1,
  int64_t batch_stride_tensor_in = 0,
  int64_t batch_stride_tensor_out = 0,
  int64_t batch_stride_tensor_scale = 0,
  int64_t batch_stride_tensor_bias = 0
) {

  int const kMblock = 4;
  int const kNblock = 4;

  dim3 block(16, 8);
  dim3 grid(
    (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
    (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
    batch_count % std::numeric_limits<uint16_t>::max()
  );

  kernel::TensorScaleBiasGemmBatched<
    TensorRef<ElementIn, Layout>,
    TensorRef<ElementOut, Layout>,
    ScalarType,
    TensorRef<ScalarType, LayoutScaleBias>,
    ConvertOp,
    kMblock,
    kNblock
  ><<< grid, block >>> (
    problem_size,
    tensor_in,
    tensor_out,
    alpha,
    tensor_scale,
    tensor_bias,
    batch_count,
    batch_stride_tensor_in,
    batch_stride_tensor_out,
    batch_stride_tensor_scale,
    batch_stride_tensor_bias
  );
}

/// Apply scale and bias on a tensor
template <
  typename ElementIn,        ///< Input Type
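For orientation, here is a hedged sketch of how the new batched host entry point might be exercised from a test. The storage choices, extents, element types, and tensor names are assumptions made for illustration; only the TensorScaleBiasGemmBatched signature comes from the diff above, and the call assumes it is made from the namespace that declares the function.

#include "cutlass/gemm_coord.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"

void run_batched_scale_bias() {
  int const batch_count = 3;
  cutlass::gemm::GemmCoord problem_size(128, 64, 32);

  // One M x N slice per batch, stored contiguously back to back.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_in(
      {batch_count * problem_size.m(), problem_size.n()});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_out(
      {batch_count * problem_size.m(), problem_size.n()});

  // Per-batch scale/bias vectors of length N (the kernel reads {0, column}).
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_scale(
      {batch_count, problem_size.n()});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_bias(
      {batch_count, problem_size.n()});

  // (Host-side initialization and sync_device() calls omitted for brevity.)

  int64_t stride_mn = int64_t(problem_size.m()) * problem_size.n();

  TensorScaleBiasGemmBatched(
      problem_size,
      tensor_in.device_ref(),
      tensor_out.device_ref(),
      1.0f,                         // alpha, used only if tensor_scale is null
      tensor_scale.device_ref(),
      tensor_bias.device_ref(),
      batch_count,
      stride_mn,                    // batch_stride_tensor_in
      stride_mn,                    // batch_stride_tensor_out
      problem_size.n(),             // batch_stride_tensor_scale
      problem_size.n());            // batch_stride_tensor_bias
}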