From 8206e7a0f57a9a057cdd2c3bb4899bd5154a82e1 Mon Sep 17 00:00:00 2001 From: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> Date: Wed, 28 May 2025 22:24:39 +0800 Subject: [PATCH] Pre-compile in CuteDsl/ampere/elementwise_apply.py (#2340) --- examples/python/CuTeDSL/ampere/elementwise_apply.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/examples/python/CuTeDSL/ampere/elementwise_apply.py b/examples/python/CuTeDSL/ampere/elementwise_apply.py index e1e18729..b395e9f5 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_apply.py +++ b/examples/python/CuTeDSL/ampere/elementwise_apply.py @@ -311,6 +311,7 @@ def run_elementwise_apply_and_verify( print("Compiling kernel with cute.compile ...") start_time = time.time() + compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") @@ -321,9 +322,7 @@ def run_elementwise_apply_and_verify( current_stream = cuda.CUstream(torch_stream.cuda_stream) if not skip_ref_check: - elementwise_apply( - op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic() - ) + compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) print("Verifying results...") torch.testing.assert_close(op(a, b), c) print("Results verified successfully!") @@ -337,18 +336,14 @@ def run_elementwise_apply_and_verify( # Warmup for _ in range(warmup_iterations): - elementwise_apply( - op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic() - ) + compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) # Record start event cuda.cuEventRecord(start_event, current_stream) # Execute the kernel for _ in range(iterations): - elementwise_apply( - op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic() - ) + compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) # Record end event cuda.cuEventRecord(end_event, current_stream)