Files
cutlass/examples/python/CuTeDSL/ampere/inline_ptx.py
Junkai-Wu b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00

246 lines
8.1 KiB
Python

# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from functools import partial
from typing import Union
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass._mlir.dialects import llvm
from cutlass.cute.typing import Boolean, Int32, Int, Constexpr
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass.cute.arch.nvvm_wrappers import FULL_MASK, WARP_SIZE
"""
A simple example to show how to wrap PTX instructions by using inline_asm op in llvm dialect.
Situations like:
1. Instructions that are not already exposed by CuTe DSL via `nvvm` module
2. Sequences of instructions that the compiler otherwise does not generate optimally
motivate developers to inline PTX themselves.
In this example, we inline vote.sync.ballot.b32, vote.sync.any.pred, vote.sync.all.pred,
and vote.sync.uni.pred, and compare them against the corresponding ops in nvvm_wrappers.py.
You can refer to the documentation of `inline_asm op in llvm dialect <https://mlir.llvm.org/docs/Dialects/LLVM/#llvminline_asm-llvminlineasmop>`_
and `vote.sync <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync>`_
for more details.
To run this example:
.. code-block:: bash
python examples/ampere/inline_ptx.py
The example runs the vote kernel twice, once with inline PTX and once with the nvvm dialect ops,
and checks that the two runs produce the same results.
"""
@dsl_user_op
def ptx_vote_sync_op(
    pred: Boolean, kind: str, mask: Int = FULL_MASK, *, loc=None, ip=None
) -> Boolean:
    """Emit ``vote.sync.<kind>.pred`` through LLVM inline assembly.

    The per-thread ``pred`` is turned into a PTX predicate and voted on by the
    threads selected by ``mask``. The voted predicate is returned as a
    ``Boolean``.

    :param pred: Predicate this thread contributes to the vote.
    :param kind: Vote mode: ``"any"``, ``"all"`` or ``"uni"``. The ballot mode
        produces a 32-bit value instead and is implemented separately by
        ``ptx_vote_ballot_sync``.
    :param mask: Member mask of participating threads. It is bound with the
        ``"i"`` (immediate) asm constraint, so it must be a compile-time
        constant.
    :param loc: Optional MLIR location forwarded to the IR builders.
    :param ip: Optional MLIR insertion point forwarded to the IR builders.
    :returns: The vote result as a ``Boolean``.
    :raises ValueError: If ``kind`` is not one of the predicate vote modes.
    """
    # ``kind`` is spliced verbatim into the PTX text below, so reject anything
    # that would form an invalid ``vote.sync.<kind>.pred`` instruction.
    if kind not in ("any", "all", "uni"):
        raise ValueError(
            f"ptx_vote_sync_op: unsupported vote kind {kind!r}; "
            "expected one of 'any', 'all', 'uni'"
        )

    # Operands: $0 = result register ("=r"), $1 = input predicate ("r"),
    # $2 = member mask ("i", immediate).
    # The input is tested with setp to form the PTX predicate the vote needs,
    # and the voted predicate is materialized as 0/1 with selp for $0.
    asm = (
        "{\n\t"
        ".reg .pred ps;\n\t"
        ".reg .pred pd;\n\t"
        "setp.ne.b32 ps, $1, 0;\n\t"
        f"vote.sync.{kind}.pred pd, ps, $2;\n\t"
        "selp.b32 $0, 1, 0, pd;\n\t"
        "}"
    )
    return Boolean(
        llvm.inline_asm(
            T.bool(),
            [
                Boolean(pred).ir_value(loc=loc, ip=ip),
                Int32(mask).ir_value(loc=loc, ip=ip),
            ],
            asm,
            "=r,r,i",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
# Predicate-returning vote variants: each binds one ``vote.sync`` mode onto
# ``ptx_vote_sync_op`` and otherwise takes the same arguments (pred, mask).
ptx_vote_any_sync, ptx_vote_all_sync, ptx_vote_uni_sync = (
    partial(ptx_vote_sync_op, kind=mode) for mode in ("any", "all", "uni")
)
@dsl_user_op
def ptx_vote_ballot_sync(
    pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None
) -> Int32:
    """Emit ``vote.sync.ballot.b32`` through LLVM inline assembly.

    The per-thread ``pred`` is turned into a PTX predicate and balloted across
    the threads selected by ``mask``. For the bit layout of the result, see
    the PTX ``vote.sync`` documentation.

    :param pred: Predicate this thread contributes to the ballot.
    :param mask: Member mask of participating threads. It is bound with the
        ``"i"`` (immediate) asm constraint, so it must be a compile-time
        constant.
    :param loc: Optional MLIR location forwarded to the IR builders.
    :param ip: Optional MLIR insertion point forwarded to the IR builders.
    :returns: The 32-bit ballot result as an ``Int32``.
    """
    # Operands: $0 = 32-bit ballot result ("=r"), $1 = input predicate ("r"),
    # $2 = member mask ("i", immediate).
    # The input is tested with setp to form the PTX predicate the vote needs.
    asm = (
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.b32 p, $1, 0;\n\t"
        "vote.sync.ballot.b32 $0, p, $2;\n\t"
        "}"
    )
    return Int32(
        llvm.inline_asm(
            T.i32(),
            [
                Boolean(pred).ir_value(loc=loc, ip=ip),
                Int32(mask).ir_value(loc=loc, ip=ip),
            ],
            asm,
            "=r,r,i",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
@cute.kernel
def vote_kernel(
    mBallot: cute.Tensor,
    mAny: cute.Tensor,
    mAll: cute.Tensor,
    mUni: cute.Tensor,
    use_inline_ptx: Constexpr[bool],
):
    """Compute warp votes on the predicate ``tidx < 10`` and store them per thread.

    Each thread contributes ``tidx < 10`` to four votes: ballot, any, all and
    uni. ``use_inline_ptx`` is a ``Constexpr``, so the choice happens at
    compile time. It selects between the inline-PTX wrappers defined in this
    file (``ptx_vote_*``) and the ``cute.arch.vote_*_sync`` built-ins. Thread
    ``tidx`` writes its results to element ``tidx`` of each output tensor.

    :param mBallot: Output for the ballot result, indexed by thread.
    :param mAny: Output for the ``any`` vote result, indexed by thread.
    :param mAll: Output for the ``all`` vote result, indexed by thread.
    :param mUni: Output for the ``uni`` vote result, indexed by thread.
    :param use_inline_ptx: True to use the inline-PTX wrappers, False to use
        the ``cute.arch`` built-ins.
    """
    # Only the x thread index is used; the y/z components are discarded.
    # NOTE(review): each output is assumed to hold at least blockDim.x elements.
    tidx, _, _ = cute.arch.thread_idx()
    # The four votes below all use the same predicate: threads with tidx < 10
    # vote True, all others vote False.
    vote_ballot = (
        ptx_vote_ballot_sync(tidx < 10)
        if use_inline_ptx
        else cute.arch.vote_ballot_sync(tidx < 10)
    )
    vote_any = (
        ptx_vote_any_sync(tidx < 10)
        if use_inline_ptx
        else cute.arch.vote_any_sync(tidx < 10)
    )
    vote_all = (
        ptx_vote_all_sync(tidx < 10)
        if use_inline_ptx
        else cute.arch.vote_all_sync(tidx < 10)
    )
    vote_uni = (
        ptx_vote_uni_sync(tidx < 10)
        if use_inline_ptx
        else cute.arch.vote_uni_sync(tidx < 10)
    )
    # One element per thread in each output.
    mBallot[tidx] = vote_ballot
    mAny[tidx] = vote_any
    mAll[tidx] = vote_all
    mUni[tidx] = vote_uni
@cute.jit
def vote(
    mBallot: cute.Tensor,
    mAny: cute.Tensor,
    mAll: cute.Tensor,
    mUni: cute.Tensor,
    use_inline_ptx: Constexpr[bool],
):
    """Host-side JIT entry point that launches ``vote_kernel`` once.

    Launches a single one-dimensional block whose x extent is derived from
    ``WARP_SIZE``. The four output tensors and the ``use_inline_ptx``
    selector are forwarded to the kernel unchanged.
    """
    # Launch configuration: one block, one-dimensional, sized from WARP_SIZE.
    grid_dims = [1, 1, 1]
    block_dims = [cute.size(WARP_SIZE, mode=[0]), 1, 1]

    kernel = vote_kernel(mBallot, mAny, mAll, mUni, use_inline_ptx)
    kernel.launch(grid=grid_dims, block=block_dims)
def _alloc_vote_outputs():
    """Allocate CUDA output buffers for one ``vote`` call.

    Returns ``(torch_tensors, cute_tensors)``. Each is a 4-tuple ordered
    ``(ballot, any, all, uni)``: the ballot buffer is int32 and the other
    three are bool, all of length ``WARP_SIZE``. Contents start out random and
    are expected to be overwritten by the kernel. The cute tensors wrap the
    torch ones via DLPack with a dynamic layout.
    """
    device = torch.device("cuda")
    torch_tensors = (
        torch.randint(0, 100, (WARP_SIZE,), device=device, dtype=torch.int32),
        *(
            torch.randint(0, 2, (WARP_SIZE,), device=device, dtype=torch.bool)
            for _ in range(3)
        ),
    )
    cute_tensors = tuple(
        from_dlpack(t).mark_layout_dynamic() for t in torch_tensors
    )
    return torch_tensors, cute_tensors


def run():
    """Run ``vote`` with inline PTX and with nvvm ops, then verify the results.

    Checks that:

    * the two runs agree element-wise;
    * each result matches the value expected for the per-lane predicate
      ``lane < 10`` used by ``vote_kernel``.

    :raises AssertionError: If any check fails.
    """
    # Number of leading lanes voting True; must match ``tidx < 10`` in vote_kernel.
    num_true_lanes = 10
    lane_preds = [lane < num_true_lanes for lane in range(WARP_SIZE)]

    (ballot_ptx, any_ptx, all_ptx, uni_ptx), ptx_args = _alloc_vote_outputs()
    # get the results from ptx
    vote(*ptx_args, use_inline_ptx=True)

    (ballot_nvvm, any_nvvm, all_nvvm, uni_nvvm), nvvm_args = _alloc_vote_outputs()
    # get the results from nvvm
    vote(*nvvm_args, use_inline_ptx=False)

    print("Verifying ballot results...")
    torch.testing.assert_close(ballot_ptx, ballot_nvvm)
    # Bit i of the ballot is set for each lane i whose predicate is true. The
    # kernel stores it as int32, so compare it as an unsigned 32-bit pattern.
    expected_ballot = sum(1 << lane for lane, p in enumerate(lane_preds) if p)
    assert torch.all((ballot_ptx.to(torch.int64) & 0xFFFFFFFF) == expected_ballot)

    print("Verifying any results...")
    torch.testing.assert_close(any_ptx, any_nvvm)
    assert torch.all(any_ptx == any(lane_preds))

    print("Verifying all results...")
    torch.testing.assert_close(all_ptx, all_nvvm)
    assert torch.all(all_ptx == all(lane_preds))

    print("Verifying uni results...")
    torch.testing.assert_close(uni_ptx, uni_nvvm)
    # "uni" is true only when every lane holds the same predicate value.
    assert torch.all(uni_ptx == (len(set(lane_preds)) == 1))

    print("Results verified successfully!")


if __name__ == "__main__":
    run()