Use cudaMemcpyAsync in gemm grouped with kRequiresPrecomputation schedule. (#2256)

Co-authored-by: Yuhang Qi <qiyuhang@bytedance.com>
This commit is contained in:
Qi Yuhang
2025-05-01 03:28:05 +08:00
committed by GitHub
parent 2b78c2fe31
commit e5b810bed1

View File

@ -127,8 +127,8 @@ private:
}
/// Copy from `data` to `workspace`
Status copy_to_workspace(void* workspace, void* data, size_t bytes) {
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
Status copy_to_workspace(void* workspace, void* data, size_t bytes, cudaStream_t stream = nullptr) {
cudaError_t cuda_error = cudaMemcpyAsync(workspace, data, bytes, cudaMemcpyHostToDevice, stream);
if (cuda_error != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
cuda_error = cudaGetLastError();
@ -142,14 +142,14 @@ private:
}
/// Precomputes scheduling information for the grouped GEMM
Status precompute(Arguments const &args, int32_t tile_count, void* workspace) {
Status precompute(Arguments const &args, int32_t tile_count, void* workspace, cudaStream_t stream = nullptr) {
size_t workspace_bytes = get_workspace_size(args);
std::vector<uint8_t> host_workspace(workspace_bytes);
BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes,
args.problem_count,
args.threadblock_count,
(void*)host_workspace.data());
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes, stream);
}
/// Reorder `data` according to `indices`
@ -361,7 +361,7 @@ public:
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
Status status = precompute(args, tile_count, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
@ -388,7 +388,7 @@ public:
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
Status update(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
size_t workspace_bytes = get_workspace_size(args);
@ -398,7 +398,7 @@ public:
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
Status status = precompute(args, tile_count, workspace, stream);
if (status != Status::kSuccess) {
return status;
}