Pre-compile in CuteDsl/ampere/elementwise_apply.py (#2340)

This commit is contained in:
Gabriel Wu
2025-05-28 22:24:39 +08:00
committed by GitHub
parent 6316b6f867
commit 8206e7a0f5

View File

@ -311,6 +311,7 @@ def run_elementwise_apply_and_verify(
print("Compiling kernel with cute.compile ...")
start_time = time.time()
compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
@ -321,9 +322,7 @@ def run_elementwise_apply_and_verify(
current_stream = cuda.CUstream(torch_stream.cuda_stream)
if not skip_ref_check:
elementwise_apply(
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
)
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
print("Verifying results...")
torch.testing.assert_close(op(a, b), c)
print("Results verified successfully!")
@ -337,18 +336,14 @@ def run_elementwise_apply_and_verify(
# Warmup
for _ in range(warmup_iterations):
elementwise_apply(
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
)
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
# Record start event
cuda.cuEventRecord(start_event, current_stream)
# Execute the kernel
for _ in range(iterations):
elementwise_apply(
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
)
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
# Record end event
cuda.cuEventRecord(end_event, current_stream)