Use cudaMemcpyAsync in gemm grouped with kRequiresPrecomputation schedule. (#2256)
Co-authored-by: Yuhang Qi <qiyuhang@bytedance.com>
This commit is contained in:
@ -127,8 +127,8 @@ private:
|
||||
}
|
||||
|
||||
/// Copy from `data` to `workspace`
|
||||
Status copy_to_workspace(void* workspace, void* data, size_t bytes) {
|
||||
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
|
||||
Status copy_to_workspace(void* workspace, void* data, size_t bytes, cudaStream_t stream = nullptr) {
|
||||
cudaError_t cuda_error = cudaMemcpyAsync(workspace, data, bytes, cudaMemcpyHostToDevice, stream);
|
||||
if (cuda_error != cudaSuccess) {
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
cuda_error = cudaGetLastError();
|
||||
@ -142,14 +142,14 @@ private:
|
||||
}
|
||||
|
||||
/// Precomputes scheduling information for the grouped GEMM
|
||||
Status precompute(Arguments const &args, int32_t tile_count, void* workspace) {
|
||||
Status precompute(Arguments const &args, int32_t tile_count, void* workspace, cudaStream_t stream = nullptr) {
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
std::vector<uint8_t> host_workspace(workspace_bytes);
|
||||
BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes,
|
||||
args.problem_count,
|
||||
args.threadblock_count,
|
||||
(void*)host_workspace.data());
|
||||
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
|
||||
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes, stream);
|
||||
}
|
||||
|
||||
/// Reorder `data` according to `indices`
|
||||
@ -361,7 +361,7 @@ public:
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
|
||||
int32_t tile_count = group_tile_count(args);
|
||||
Status status = precompute(args, tile_count, workspace);
|
||||
Status status = precompute(args, tile_count, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
@ -388,7 +388,7 @@ public:
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
Status update(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
@ -398,7 +398,7 @@ public:
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
|
||||
int32_t tile_count = group_tile_count(args);
|
||||
Status status = precompute(args, tile_count, workspace);
|
||||
Status status = precompute(args, tile_count, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user