Pre-compile in CuteDsl/ampere/elementwise_apply.py (#2340)
This commit is contained in:
@ -311,6 +311,7 @@ def run_elementwise_apply_and_verify(
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
@ -321,9 +322,7 @@ def run_elementwise_apply_and_verify(
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
if not skip_ref_check:
|
||||
elementwise_apply(
|
||||
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
|
||||
)
|
||||
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(op(a, b), c)
|
||||
print("Results verified successfully!")
|
||||
@ -337,18 +336,14 @@ def run_elementwise_apply_and_verify(
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
elementwise_apply(
|
||||
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
|
||||
)
|
||||
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
|
||||
# Record start event
|
||||
cuda.cuEventRecord(start_event, current_stream)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
elementwise_apply(
|
||||
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
|
||||
)
|
||||
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
|
||||
# Record end event
|
||||
cuda.cuEventRecord(end_event, current_stream)
|
||||
|
||||
Reference in New Issue
Block a user