diff --git a/examples/python/CuTeDSL/ampere/elementwise_apply.py b/examples/python/CuTeDSL/ampere/elementwise_apply.py index e1e18729..b395e9f5 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_apply.py +++ b/examples/python/CuTeDSL/ampere/elementwise_apply.py @@ -311,6 +311,7 @@ def run_elementwise_apply_and_verify( print("Compiling kernel with cute.compile ...") start_time = time.time() + compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") @@ -321,9 +322,7 @@ def run_elementwise_apply_and_verify( current_stream = cuda.CUstream(torch_stream.cuda_stream) if not skip_ref_check: - elementwise_apply( - op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic() - ) + compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) print("Verifying results...") torch.testing.assert_close(op(a, b), c) print("Results verified successfully!") @@ -337,18 +336,14 @@ def run_elementwise_apply_and_verify( # Warmup for _ in range(warmup_iterations): - elementwise_apply( - op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic() - ) + compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) # Record start event cuda.cuEventRecord(start_event, current_stream) # Execute the kernel for _ in range(iterations): - elementwise_apply( - op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic() - ) + compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) # Record end event cuda.cuEventRecord(end_event, current_stream)