v4.1 release
@@ -15,6 +15,8 @@ The preprocessor reads through Python's AST and changes the input code.
"""

from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated
import warnings

from .utils.logger import log
from .common import *
@@ -30,13 +32,9 @@ class Executor:
    set_functions: Assigns the functions for checking loop bounds and
        conditional evaluation.

    for_dynamic: Generates MLIR for OP
    for_constexpr: Executes a for loop at JIT compile-time
    for_execute: Decides whether to execute the loop at compile-time or generate MLIR for OP based on the provided bounds.

    if_dynamic: Generates MLIR if OP
    if_constexpr: Executes an if statement at JIT compile-time via the Python interpreter
    if_execute: Decides whether to execute the if statement at compile-time or generate MLIR if OP based on the predicate.
    for_execute: Generates MLIR for OP
    while_execute: Generates MLIR while OP
    if_execute: Generates MLIR if OP
    """

    def __init__(self):
@@ -44,6 +42,9 @@ class Executor:
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None

    def set_functions(
        self,
@@ -51,11 +52,17 @@ class Executor:
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Callable = None,
        all_executor: Callable = None,
    ):
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor

    @staticmethod
    def convert_to_list(x):
@@ -83,31 +90,6 @@ class Executor:
            return res[0]
        return res

    def for_dynamic(
        self,
        func: Callable,
        start,
        stop,
        step,
        used_args: list,
        iter_args: list,
        iter_arg_names: list,
        unroll=bool,
        unroll_full=int,
    ):
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    @staticmethod
    def for_constexpr(
        func: Callable,
@@ -143,44 +125,14 @@ class Executor:
        iter_arg_names=[],
        unroll=-1,
        unroll_full=False,
        is_range_constexpr=None,
        pipelining=None,
    ):
        assert (
            self._loop_execute_range_dynamic and self._is_dynamic_expression
            self._loop_execute_range_dynamic
        ), "Functions must be set before execution."
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        any_dynamic_expression = (
            self._is_dynamic_expression(start)
            or self._is_dynamic_expression(stop)
            or self._is_dynamic_expression(step)
        )

        if is_range_constexpr is None:
            if not any_dynamic_expression:
                return self.for_constexpr(func, start, stop, step, used_args, iter_args)
            else:
                return self.for_dynamic(
                    func,
                    start,
                    stop,
                    step,
                    used_args,
                    iter_args,
                    iter_arg_names,
                    unroll,
                    unroll_full,
                )

        # Ensure bounds are compile-time constants for constexpr execution
        if is_range_constexpr:
            if any_dynamic_expression:
                raise DSLRuntimeError(
                    "Loop bounds must be constexpr (compile-time constants)"
                )
            return self.for_constexpr(func, start, stop, step, used_args, iter_args)

        # MLIR generation
        return self.for_dynamic(
        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
@@ -190,40 +142,9 @@ class Executor:
            iter_arg_names,
            unroll,
            unroll_full,
            pipelining,
        )

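Note on the dispatch above: the only signal that matters is whether any loop bound carries an IR value. A minimal, self-contained sketch of the same decision — the names `is_dynamic` and `emit_for_op` are illustrative placeholders, not the DSL's API:

    def run_or_emit_for(body, start, stop, step, is_dynamic, emit_for_op):
        # Any IR-backed bound forces generation of an MLIR for op.
        if is_dynamic(start) or is_dynamic(stop) or is_dynamic(step):
            return emit_for_op(body, start, stop, step)
        # All bounds are Python ints: execute (unroll) the loop at compile time.
        result = None
        for i in range(start, stop, step):
            result = body(i)
        return result
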
    def if_dynamic(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
    ):
        return self._if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    @staticmethod
    def if_constexpr(
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=[],
        yield_args=[],
    ):
        if pred:
            log().debug(" running then block [%s]", yield_args)
            res = then_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        elif else_block is not None:
            log().debug("running else [%s]", yield_args)
            res = else_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)

    def if_execute(
        self,
        pred,
@@ -232,94 +153,14 @@ class Executor:
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
        if_constexpr=None,
    ):
        assert (
            self._if_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."

        is_if_constexpr = not self._is_dynamic_expression(pred)
        if if_constexpr is None:
            if is_if_constexpr:
                return self.if_constexpr(
                    pred, then_block, else_block, used_args, yield_args
                )
            else:
                return self.if_dynamic(
                    pred, then_block, else_block, used_args, yield_args, yield_arg_names
                )

        # Ensure bounds are compile-time constants for constexpr execution
        if if_constexpr:
            if not is_if_constexpr:
                raise DSLRuntimeError(
                    "If predicate must be constexpr (compile-time constants)"
                )
            return self.if_constexpr(
                pred, then_block, else_block, used_args, yield_args
            )
        assert self._if_dynamic, "Functions must be set before execution."

        # MLIR generation
        return self.if_dynamic(
        return self._if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

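For intuition, both branches arrive as callables so `if_execute` can either run one of them directly (constexpr predicate) or trace both into the two regions of an MLIR if op. A hedged usage sketch — the lambdas and variable names here are hypothetical:

    # Both branches must yield the same set of values so the generated
    # if op has a consistent result signature.
    result = executor.if_execute(
        pred,              # Python bool (constexpr) or IR value (dynamic)
        lambda x: x + 1,   # then block
        lambda x: x - 1,   # else block
        used_args=[],
        yield_args=[x],
        yield_arg_names=["x"],
    )
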
    def while_dynamic(
        self,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
    ):
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )

    @staticmethod
    def while_constexpr(
        while_before_block,
        while_after_block,
        used_args=[],
        yield_args=[],
    ):
        log().debug(
            "while_constexpr begin %s", while_before_block.__qualname__
        )
        cond, loop_results = while_before_block(*used_args, *yield_args)
        while cond:
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_after [%s], [%s]",
                used_args,
                loop_results,
            )
            loop_results = while_after_block(*used_args, *loop_results)
            log().debug(
                "while after [%s]", loop_results
            )
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_before [%s], [%s]",
                used_args,
                loop_results,
            )
            cond, loop_results = while_before_block(*used_args, *loop_results)
            log().debug(
                "while_before cond, results [%s], [%s]",
                cond,
                loop_results,
            )

        log().debug(
            "while_constexpr results %s", loop_results
        )
        return Executor.converge_ret_val(loop_results)

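The before/after split mirrors MLIR's while op: the "before" region computes the continuation condition and forwards carried values, while the "after" region is the loop body. A minimal Python analogue of the constexpr path, simplified to carried values only:

    def while_constexpr_sketch(before, after, init_vals):
        # `before` returns (condition, carried_values); `after` returns new values.
        cond, vals = before(*init_vals)
        while cond:
            vals = after(*vals)
            cond, vals = before(*vals)
        return vals
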
    def while_execute(
        self,
        pred,
@@ -328,26 +169,11 @@ class Executor:
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
        while_constexpr=None,
    ):
        assert (
            self._while_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."

        is_while_constexpr = not self._is_dynamic_expression(pred)

        # Ensure bounds are compile-time constants for constexpr execution
        if while_constexpr:
            if not is_while_constexpr:
                raise DSLRuntimeError(
                    "While predicate must be constexpr (compile-time constants)"
                )
            return self.while_constexpr(
                while_before_block, while_after_block, used_args, yield_args
            )
        assert self._while_dynamic, "Functions must be set before execution."

        # MLIR generation
        return self.while_dynamic(
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
@@ -367,15 +193,16 @@ def loop_selector(
    start,
    stop,
    step,
    *,
    used_args=[],
    iter_args=[],
    iter_arg_names=[],
    unroll=-1,
    unroll_full=False,
    constexpr=None,
    pipelining=None,
):
    log().debug(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
        start,
        stop,
        step,
@@ -383,7 +210,7 @@ def loop_selector(
        iter_args,
        unroll,
        unroll_full,
        constexpr,
        pipelining,
    )
    from .typing import Integer, Numeric

@@ -408,7 +235,7 @@ def loop_selector(
        iter_arg_names,
        unroll,
        unroll_full,
        constexpr,
        pipelining,
    )

    return ir_loop
@@ -443,7 +270,6 @@ def while_executor(
    used_args=[],
    yield_args=[],
    yield_arg_names=[],
    constexpr=None,
):
    return executor.while_execute(
        pred,
@@ -452,7 +278,6 @@ def while_executor(
        used_args,
        yield_args,
        yield_arg_names,
        constexpr,
    )


@ -475,75 +299,70 @@ def if_executor(
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class range_dynamic:
|
||||
class range:
|
||||
@overload
|
||||
def __new__(cls, stop, unroll=0, unroll_full=False):
|
||||
def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __new__(cls, start, stop, step, unroll=0, unroll_full=False):
|
||||
def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None):
|
||||
pass
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
|
||||
|
||||
|
||||
class range_constexpr:
    def __init__(self, *args):
        if len(args) == 1:
            self.start = 0
            self.stop = args[0]
            self.step = 1
        elif len(args) == 2:
            self.start, self.stop = args
            self.step = 1
        elif len(args) == 3:
            self.start, self.stop, self.step = args
        else:
            raise DSLRuntimeError(
                "range_constexpr supports up to 3 arguments (start, stop, step)"
            )
        # Ensure the arguments are compile-time constants (if required)
        for arg_name, arg_value in [
            ("step", self.step),
            ("start", self.start),
            ("stop", self.stop),
        ]:
            if executor._is_dynamic_expression(arg_value):
                raise DSLRuntimeError(
                    f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
                    f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
                    f"will handle them during runtime. ",
                    suggestion="Use `range` instead of `range_constexpr`.",
                )
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")

    def __iter__(self) -> Iterator[int]:
        current = self.start
        while current < self.stop:
            yield current
            current += self.step
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")


@deprecated(
    "range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*args, **kwargs):
    raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")


def range_constexpr(*args):
    raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")

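To make the two spellings concrete, a hedged sketch of their behavior inside a JIT-compiled function — the `@jit` decorator and `Int32` type are assumed from the surrounding DSL:

    @jit
    def kernel(n: Int32):
        # Bounds may be dynamic: the preprocessor rewrites this into an MLIR for op.
        for i in range(n):
            ...
        # Bounds must be compile-time Python ints: the loop is unrolled by the
        # Python interpreter at JIT compile time; otherwise a DSLRuntimeError is raised.
        for j in range_constexpr(8):
            ...
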
# =============================================================================
# If expressions
# =============================================================================


def const_expr(expression):
    if executor._is_dynamic_expression(expression):
    """
    This function is used to check if the expression is a Python value.
    If the expression is a Python value, return the boolean value of the expression.
    If the expression is a dynamic expression, raise an error.
    """
    from .typing import Numeric

    failed = False

    if isinstance(expression, Numeric):
        if isinstance(expression.value, (int, float, bool)):
            return expression.value
        else:
            failed = True
    elif executor._is_dynamic_expression(expression):
        failed = True

    if failed:
        raise DSLRuntimeError(
            f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
            context={
                "const_expr": "Accepts only constexpr (compile-time constant)",
                "If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
                "If the expression could be either dynamic or constexpr": "Omit the explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
                "If your expression depends on dynamic values": "Remove `const_expr()`",
            },
        )
    return expression

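Illustrative behavior of `const_expr` (the values below are hypothetical):

    flag = True
    if const_expr(flag):          # fine: a plain Python bool passes through
        ...
    x = Int32(some_ir_value)      # hypothetical Numeric wrapping an IR value
    const_expr(x)                 # raises DSLRuntimeError: not compile-time constant
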
@deprecated(
    "dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
    raise DSLRuntimeError("dynamic_expr should be always preprocessed to IR")
    return expression


# =============================================================================
@@ -582,3 +401,86 @@ def bool_cast(value):
            suggestion = "Please explicitly convert to boolean with expressions like comparison."
        )
    return bool(value)

def compare_executor(left, comparators, ops):
    """
    Executes comparison operations with a left operand and a list of comparators.

    Args:
        left: The leftmost value in the comparison chain
        comparators: A list of values to compare against
        ops: A list of comparison operators to apply

    Returns:
        The result of the comparison chain

    Raises:
        AssertionError: If the executor function is not set before execution
    """
    assert (
        executor._compare_executor is not None
    ), "Function must be set before execution."
    return executor._compare_executor(left, comparators, ops)

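For intuition, a chained Python comparison such as `a < b <= c` reaches this hook as one call, roughly `compare_executor(a, [b, c], [lt, le])`. A pure-Python sketch of that desugaring (the real executor combines IR values instead of Python bools):

    import operator

    def compare_chain(left, comparators, ops):
        # Evaluate pairwise, left-to-right, and combine with logical and.
        result = True
        for right, op in zip(comparators, ops):
            result = result and op(left, right)
            left = right
        return result

    assert compare_chain(1, [2, 2], [operator.lt, operator.le]) is True
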
def any_executor(iterable):
    """Executes the 'any' operation on an iterable, handling both dynamic and static expressions.

    :param iterable: An iterable to check if any elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean

    """
    if executor._any_executor and executor._is_dynamic_expression(iterable):
        return executor._any_executor(iterable)
    else:
        return any(iterable)


def all_executor(iterable):
    """Executes the 'all' operation on an iterable, handling both dynamic and static expressions.

    :param iterable: An iterable to check if all elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    if executor._all_executor and executor._is_dynamic_expression(iterable):
        return executor._all_executor(iterable)
    else:
        return all(iterable)

# =============================================================================
# Control flow checks
# =============================================================================
def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
    """
    try:
        return tuple(arg.__index__() for arg in args)
    except:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        )

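The conversion relies on Python's `__index__` protocol (also added to `Numeric` in this release), so plain ints and int-backed DSL values pass while IR-backed values raise. Illustrative calls with hypothetical values:

    range_value_check(4, Int32(8))    # -> (4, 8) when Int32 wraps a Python int
    range_value_check(dyn_val)        # raises DSLRuntimeError for an IR-backed value
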
def range_perf_warning(filename, lineno, *args):
    has_dynamic_expr = False
    for arg in args:
        if executor._is_dynamic_expression(arg):
            has_dynamic_expr = True
            break
    if not has_dynamic_expr:
        warnings.warn_explicit(
            (
                "The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression. "
                "If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`."
            ),
            category=UserWarning,
            filename=filename,
            lineno=lineno,
        )

File diff suppressed because it is too large
@@ -164,16 +164,17 @@ def _mlir_type_to_numpy_type(type):

def is_dynamic_expression(value):
    """
    Check if the value is an MLIR's SSA value.
    Given the `value`, check if it is itself an IR value, or recursively go through it to check if it contains an IR value.
    """
    # Case 1: If the value has MLIR's SSA value, return True
    # Case 2: If the value supports __extract_mlir_values__ then it's possible to get SSA value
    return (
        isinstance(value, ir.Value)
        or hasattr(value, "__extract_mlir_values__")
        or len(extract_mlir_values(value)) > 0
    )

    if isinstance(value, (tuple, list)):
        for x in value:
            if is_dynamic_expression(x):
                return True
    elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr(
        value, "__extract_mlir_values__"
    ):
        return True
    return False

def extract_mlir_values(obj):
    """
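The new definition recurses into containers, so a tuple or list is dynamic as soon as any element is. Illustrative results (`ir_value` stands for some MLIR SSA value):

    is_dynamic_expression(3)               # False: plain Python int
    is_dynamic_expression((1, ir_value))   # True: one element is an ir.Value
    is_dynamic_expression([2.0, "text"])   # False: no IR values anywhere
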
@@ -726,6 +727,7 @@ class BaseDSL:
        )

        jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
        jit_adapted_args = []
        default_attr = ir.DictAttr.get({})

        input_args = [*args, *kwargs.values()]
@@ -759,7 +761,9 @@ class BaseDSL:
            # If not any known type, try JIT argument adapter
            # to convert the argument
            adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
            arg = adapter(arg) if adapter else arg
            if adapter:
                arg = adapter(arg)
            jit_adapted_args.append(arg)

            if is_host:
                jit_exec_arg.extend(get_c_pointers(arg))
@@ -798,14 +802,14 @@ class BaseDSL:
        jit_arg_types.extend(jit_arg_type)
        jit_arg_attrs.extend(jit_arg_attr)

        return jit_exec_args, jit_arg_types, jit_arg_attrs
        return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args

    def generate_mlir_function_types(
        self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec
    ):
        """Convert input arguments to an MLIR function signature, and also convert numpy arrays to memref."""

        exe_args, types, _ = self._generate_jit_func_args(
        exe_args, types, attrs, adapted_args = self._generate_jit_func_args(
            func, function_name, input_args, kwargs, args_spec, is_host=True
        )

@@ -816,7 +820,7 @@ class BaseDSL:
            types
        ), "expects the same number of arguments and function parameters"

        return exe_args, types
        return exe_args, types, adapted_args

    @dataclass
    class LaunchConfig:
@@ -1158,7 +1162,7 @@ class BaseDSL:
        """Generate MLIR module and compile it."""
        with ir.Context(), ir.Location.unknown():
            # Convert input arguments to MLIR arguments
            exe_args, func_types = self.generate_mlir_function_types(
            exe_args, func_types, adapted_args = self.generate_mlir_function_types(
                funcBody, function_name, args, kwargs, args_spec
            )

@@ -1476,7 +1480,7 @@ class BaseDSL:
        if self.device_compilation_only:
            return kernel_operands, kernel_arg_types, kernel_arg_attrs

        kernel_operands, kernel_arg_types, kernel_arg_attrs = (
        kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = (
            self._generate_jit_func_args(
                kernel_func, kernel_name, args, kwargs, args_spec, is_host=False
            )
@@ -1586,12 +1590,14 @@ class BaseDSL:
        if self.device_compilation_only:
            log().debug("Generating cuda-python arguments")
            # Convert input arguments to MLIR arguments
            self.exe_args, kernel_types = self.generate_mlir_function_types(
                funcBody,
                kernel_name,
                canonicalized_args,
                canonicalized_kwargs,
                args_spec,
            self.exe_args, kernel_types, _ = (
                self.generate_mlir_function_types(
                    funcBody,
                    kernel_name,
                    canonicalized_args,
                    canonicalized_kwargs,
                    args_spec,
                )
            )

        helper = kernelGenHelper()

@@ -78,11 +78,9 @@ def detect_gpu_arch(prefix):

    major, minor = arch
    suffix = ""
    if major >= 9 and minor >= 0:
    if major >= 9:
        suffix = "a"
    elif minor != 0:
        # e.g. sm_86 belongs to the sm_80 family
        minor = 0

    return f"sm_{major}{minor}{suffix}"


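Under the new rule, hypothetical (major, minor) pairs map as follows:

    # (9, 0)  -> "sm_90a"   : sm_90 and newer always get the "a" suffix
    # (10, 0) -> "sm_100a"
    # (8, 6)  -> "sm_80"    : pre-sm_90 minor revisions fold into the x0 family
    # (8, 0)  -> "sm_80"
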
@@ -12,23 +12,24 @@
"""
This module provides JIT executor related classes
"""
import io
import inspect
import ctypes
import numpy as np
import inspect
import io
from typing import get_origin

import numpy as np

# MLIR modules imports
from .._mlir import ir

# Local modules imports
from .utils.timer import timer
from .utils.logger import log
from . import typing as t
from .common import DSLRuntimeError
from .runtime import cuda as cuda_helpers
from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
from .typing import get_c_pointers
from . import typing as t

# MLIR modules imports
from .._mlir import ir
from .utils.logger import log
from .utils.timer import timer


class CudaSingleModule:
@@ -64,6 +65,7 @@ class JitExecutor:
        self.args_spec = args_spec
        self.function_name = function_name
        if args_spec is not None:
            self.original_args_spec = args_spec
            self.args_spec = self.filter_runtime_arg_spec(args_spec)
        # cuda kernels
        self.cuda_modules = cuda_modules
@@ -135,6 +137,29 @@ class JitExecutor:
        for module in set(cuda_modules):
            cuda_helpers.unload_cubin_module(module)

    def get_constexpr_args(self) -> list[dict[str, int | str]]:
        """
        This function returns the constexpr args that have been pruned from the original function signature.
        The return type is a list of dicts; each dict contains the argument index (argument_index) and argument name (argument_name).

        :return: list of dicts; each dict contains the argument index (argument_index) and argument name (argument_name).
        :rtype: list[dict[str, int | str]]
        """
        if self.original_args_spec is None:
            return list()
        constexpr_args = list()
        for i, arg_name in enumerate(self.original_args_spec.args):
            if arg_name not in self.args_spec.args:
                constexpr_args.append({"argument_index": i, "argument_name": arg_name})

        if self.original_args_spec.kwonlyargs:
            for kwarg in self.original_args_spec.kwonlyargs:
                if kwarg not in self.args_spec.kwonlyargs:
                    constexpr_args.append(
                        {"argument_index": None, "argument_name": kwarg}
                    )
        return constexpr_args

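A hedged sketch of the returned shape for a hypothetical signature `def kernel(a, b, n: Constexpr, *, mode: Constexpr)`:

    executor.get_constexpr_args()
    # -> [{"argument_index": 2, "argument_name": "n"},
    #     {"argument_index": None, "argument_name": "mode"}]
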
    def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec):
        """
        This function is the pruned version of `generate_mlir_function_types` which only generates execution args
@@ -175,6 +200,7 @@ class JitExecutor:
        )

        exe_args = []
        adapted_args = []
        input_args = rectified_args + list(rectified_kwargs.values())
        input_arg_names = args_spec.args + args_spec.kwonlyargs
        for arg, arg_name in zip(input_args, input_arg_names):
@@ -193,13 +219,16 @@ class JitExecutor:
            adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
            if adapter:
                arg = adapter(arg)
            adapted_args.append(arg)

            exe_args.extend(get_c_pointers(arg))

        return exe_args
        return exe_args, adapted_args

    def __call__(self, *args, **kwargs):
        exe_args = self.generate_execution_args(args, kwargs, self.args_spec)
        exe_args, adapted_args = self.generate_execution_args(
            args, kwargs, self.args_spec
        )

        self.run_compiled_program(exe_args)


@@ -46,29 +46,75 @@ from .._mlir.dialects import arith, math

@runtime_checkable
class DynamicExpression(Protocol):
    """
    This is a protocol class that provides a common interface
    to generate user-defined dynamic expressions.
    """Protocol defining the interface for objects holding dynamic values in the DSL.

    The DSL checks this protocol to determine if a class is a dynamic expression (SSA value) or not.
    This protocol enables classes to represent dynamic values in the DSL. Classes implementing
    this protocol can be used in JIT-compiled functions and dynamic value generation.

    It is required for custom data types to work correctly with the following JIT features:
    * as function argument to call another JIT function from a JIT function
    * as return value from a JIT function
    * for constructions like if-else, while-loop, etc.

    :param value: The MLIR operation result value to initialize the object with
    :type value: ir.Value

    **Required Methods**

    * ``__extract_mlir_values__``: Extract MLIR values from the object
    * ``__new_from_mlir_values__``: Create new instance from MLIR values

    **Implementation Example**

    To implement a custom data type that works with the DSL:

    .. code-block:: python

        class CustomData(metaclass=DslType):
            def __init__(self, int_value):
                self.int_value = int_value

            def __extract_mlir_values__(self):
                return [self.int_value]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0])

    **Usage in JIT Functions**

    When used in JIT-compiled functions, the DSL automatically extracts MLIR values:

    .. code-block:: python

        @jit
        def caller():
            x = CustomData(1)
            return foo(x)

    This generates MLIR like:

    .. code-block:: mlir

        func @caller() -> i32 {
            %0 = func.call @foo(%arg0) : (i32) -> i32
            return %0 : i32
        }
    """

    def __extract_mlir_values__(self):
        """
        Generate a dynamic expression for the current object.
        """Extract MLIR values from this object.

        :return: List of MLIR values
        :return: List of MLIR values representing this object's data
        :rtype: List[ir.Value]
        """
        raise NotImplementedError

    def __new_from_mlir_values__(self, values):
        """
        Create a new object from MLIR values.
        """Create a new instance from MLIR values.

        :param values: List of MLIR values
        :param values: List of MLIR values to construct the object from
        :type values: List[ir.Value]
        :return: A new instance of the class that implements this protocol
        :return: New instance of the implementing class
        :rtype: Any
        """
        raise NotImplementedError
@@ -77,50 +123,73 @@ class DynamicExpression(Protocol):

@runtime_checkable
class JitArgument(Protocol):
    """
    This is a protocol class that provides a common interface
    for JIT function arguments generation for Python to call JIT functions.
    Protocol class defining the interface for JIT function argument generation.

    The DSL checks this protocol to determine if a class is capable of providing information
    needed for generating JIT function arguments.
    This protocol enables classes to provide the necessary information for generating
    JIT function arguments and allows the DSL JIT executor to call JIT-compiled functions.

    See breakdowns below for JitArgument protocol based JIT function calls.
    **Required Methods**

    * ``__c_pointers__``: Returns ctypes pointers for runtime execution
    * ``__get_mlir_types__``: Returns MLIR types for function definition
    * ``__new_from_mlir_values__``: Creates new instances from MLIR values

    **Example**

    .. code-block:: python

        class CustomData:
            def __init__(self, int_value, ...):
                self.int_value = int_value
                ...

            def __c_pointers__(self):
                return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...]

            def __get_mlir_types__(self):
                return [ir.IntegerType.get(32), ...]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0], ...)

        @jit
        def foo(x: CustomData):
            return x.int_value + 1
            a = x.int_value + 1
            ...

        # Emit: `%c0 = arith.constant(1, i32)`
        c1 = const(1, Int32)
        # `c1` tracks `%c0` defined outside of the function body of `foo`
        # `%c0` can't be used directly in the function body of `foo`
        x = CustomData(c1, ...)
        # `CustomData` is an argument of `foo`
        foo(CustomData(1, ...))

    When called like ``y = foo(x)``, the following steps occur:

    1. JIT compiler generates MLIR function definition using ``__get_mlir_types__``:
    1. JIT compiler generates MLIR function definition using ``__get_mlir_types__``

    .. code-block:: mlir

        func @foo(%arg0: i32, ...) -> i32 {
        func.func @foo(%arg0: i32, ...) {
            ...

            return
        }

    2. Function is traced in Python, wrapping MLIR values with ``__new_from_mlir_values__``:
    2. JIT function can't use values from Python, so it needs to reconstruct the object from
       MLIR values, a.k.a. `%arg0`, with ``__new_from_mlir_values__`` and pass it to `foo`.

       The following code demonstrates how the JIT compiler reconstructs the object and passes it to Python.

    .. code-block:: python

        # Implementation of IR tracing
        new_x = CustomData(ir.Value(%arg0), ...)
        y = foo(new_x)
        # `x.int_value` is %arg0 rather than `c1` defined outside
        # `x.int_value` is %arg0 rather than `c1` defined by Python.

    3. For Python runtime execution, JIT engine invokes compiled function using ``__c_pointers__``:
    3. For Python runtime execution, the JIT engine invokes the compiled function using ``__c_pointers__``
       pointing to the underlying data object passed to the JIT-compiled function.

    .. code-block:: python

        jit_engine.invoke(foo, concat([x.__c_pointers__(), ...]))
        jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...]))
    """

    def __c_pointers__(self):
@@ -224,47 +293,6 @@ class DslType(type):
    :property mlir_type: Returns the corresponding MLIR type for this DSL type
    :type mlir_type: Any

    **Examples**

    Define a custom data type:

    .. code-block:: python

        class CustomData(metaclass=DslType, ...):
            def __init__(self, int_value, ...):
                self.int_value = int_value
                ...

            def __str__(cls):
                return "CustomData[int, ...]"

            def __c_pointers__(self):
                return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...]

            def __get_mlir_types__(self):
                return [_T.i32(), ...]

            def __extract_mlir_values__(self):
                return [self.int_value, ...]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0], ...)

    For JIT function calls, MLIR values are extracted with ``__extract_mlir_values__``:

    .. code-block:: python

        @jit
        def caller():
            x = CustomData(1, ...)
            return foo(x)

    .. code-block:: mlir

        func @caller() -> i32 {
            %0 = func.call @foo(%arg0, ...) : (i32, ...) -> i32
            return %0 : i32
        }
    """

    _is_abstract: bool
@@ -946,9 +974,12 @@ class Numeric(metaclass=NumericMeta, is_abstract=True):
        :return: The result of the logical not operation
        :rtype: Boolean
        """
        ty = type(self)
        zero_val = arith.constant(ty.mlir_type, ty.zero)
        return self.__eq__(ty(zero_val), loc=loc, ip=ip)
        if isinstance(self.value, (int, float, bool)):
            return not self.value
        else:
            ty = type(self)
            zero_val = arith.constant(ty.mlir_type, ty.zero)
            return self.__eq__(ty(zero_val), loc=loc, ip=ip)

    def __dsl_and__(self, other, *, loc=None, ip=None):
        """DSL implementation of Python's `and` operator.
@@ -1057,6 +1088,15 @@ class Numeric(metaclass=NumericMeta, is_abstract=True):
            ],
        )

    def __index__(self):
        if isinstance(self.value, (int, float, bool)):
            return self.value
        else:
            raise DSLRuntimeError(
                f"'{type(self.value)}' object cannot be interpreted as an integer",
                suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with the `jit` decorator",
            )

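This `__index__` hook is what lets int-backed Numeric values flow into places that demand exact integers (`range`, slicing, `range_value_check` above). Illustrative behavior with hypothetical values:

    n = Int32(8)             # wraps a Python int
    list(range(n))           # works: __index__ returns 8
    m = Int32(ssa_value)     # hypothetical Numeric wrapping an IR value
    list(range(m))           # raises DSLRuntimeError with the suggestion above
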
    def __neg__(self, *, loc=None, ip=None):
        if isinstance(self, (bool, int, float)):
            return type(self)(-self.value)  # type: ignore
@@ -1813,7 +1853,7 @@ class IRVariadic:

    def __init__(self, operands):
        """
        Create a list of variadic operands. `operands` must be SSA values.
        Create a list of variadic operands. `operands` must be dynamic values.
        """
        self.operands = operands


@@ -68,7 +68,9 @@ from .core import (
    select,
    front,
    is_major,
    leading_dim,
    find,
    find_if,
    coalesce,
    group_modes,
    cosize,
@@ -221,7 +223,9 @@ __all__ = [
    "select",
    "front",
    "is_major",
    "leading_dim",
    "find",
    "find_if",
    "coalesce",
    "group_modes",
    "cosize",

@@ -25,12 +25,13 @@ __all__ = [
    #
    # mbar.py
    #
    "mbarrier_init_arrive_cnt",
    "mbarrier_init",
    "mbarrier_init_fence",
    "mbarrier_init_tx_bytes",
    "mbarrier_arrive_and_expect_tx",
    "mbarrier_expect_tx",
    "mbarrier_wait",
    "mbarrier_try_wait",
    "conditional_mbarrier_try_wait",
    "mbarrier_conditional_try_wait",
    "mbarrier_arrive",
    #
    # nvvm_wrappers.py
@@ -51,6 +52,7 @@ __all__ = [
    "shuffle_sync_down",
    "shuffle_sync_bfly",
    "barrier",
    "barrier_arrive",
    "sync_threads",
    "sync_warp",
    "fence_acq_rel_cta",

@@ -69,7 +69,16 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion:
            pass
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )
    is_thread_leader = nvvm.elect_sync(T.bool())
    if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
    return IfOpRegion(if_op.then_block, loc=loc, ip=ip)

@@ -8,6 +8,7 @@
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Optional

from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op

@@ -26,7 +27,7 @@ from ...impl_utils import check_value_in


@dsl_user_op
def mbarrier_init_arrive_cnt(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
    """
    Initializes an mbarrier with the specified thread arrival count.

@@ -46,16 +47,25 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None:
    A fence operation that applies to the mbarrier initializations.
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )
    nvvm.fence_mbarrier_init(loc=loc, ip=ip)


@dsl_user_op
def mbarrier_init_tx_bytes(
def mbarrier_arrive_and_expect_tx(
    mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
    """
    Initializes an mbarrier with the specified number of transaction bytes.
    Arrives on an mbarrier and expects a specified number of transaction bytes.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
@@ -66,7 +76,16 @@ def mbarrier_init_tx_bytes(
                                     SMEM.
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    mbar_llvm_ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is not None:
@@ -91,6 +110,56 @@ def mbarrier_init_tx_bytes(
    )


@dsl_user_op
def mbarrier_expect_tx(
    mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
    """
    Expects a specified number of transaction bytes without an arrive.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param bytes: The number of transaction bytes
    :type bytes: Int
    :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
                                     the mbarrier is converted to a remote address in the peer CTA's
                                     SMEM.
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    mbar_llvm_ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is not None:
        mbar_llvm_ptr = nvvm.mapa(
            mbar_llvm_ptr.type,
            mbar_llvm_ptr,
            Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        space = nvvm.MBarrierSpaceKind.CLUSTER
    else:
        space = nvvm.MBarrierSpaceKind.CTA

    nvvm.mbarrier_txn(
        mbar_llvm_ptr,
        Int32(bytes).ir_value(loc=loc, ip=ip),
        kind=nvvm.MBarrierTxnKind.EXPECT_TX,
        space=space,
        loc=loc,
        ip=ip,
    )


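Putting the renamed entry points together, a typical TMA producer/consumer handshake looks roughly like this — a hedged sketch; pointers, byte counts, and the surrounding kernel scaffolding are assumed:

    # Producer side: initialize, publish, then post the expected transaction bytes.
    mbarrier_init(mbar_ptr, cnt=1)
    mbarrier_init_fence()
    mbarrier_arrive_and_expect_tx(mbar_ptr, bytes=tma_bytes)
    # ... issue the bulk tensor copy that commits its bytes to this mbarrier ...

    # Consumer side: block until the expected bytes for this phase have arrived.
    mbarrier_wait(mbar_ptr, phase)
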
@dsl_user_op
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
    """
@@ -102,7 +171,16 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
    :type phase: Int
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    timeout_ns = 10000000
    # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
@@ -129,7 +207,16 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo
    :rtype: Boolean
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    return Boolean(
        nvvm.mbarrier_wait_parity(
@@ -144,7 +231,7 @@


@dsl_user_op
def conditional_mbarrier_try_wait(
def mbarrier_conditional_try_wait(
    cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None
) -> Boolean:
    """
@@ -159,7 +246,16 @@ def conditional_mbarrier_try_wait(
    :rtype: Boolean
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )
    return if_generate(
        cond,
        lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
@@ -171,7 +267,11 @@ def conditional_mbarrier_try_wait(

@dsl_user_op
def mbarrier_arrive(
    mbar_ptr: Pointer, peer_cta_rank_in_cluster: Int = None, *, loc=None, ip=None
    mbar_ptr: Pointer,
    peer_cta_rank_in_cluster: Optional[Int] = None,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Arrives on an mbarrier.
@@ -185,7 +285,16 @@ def mbarrier_arrive(
    mbar_llvm_ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is not None:
        arch = CuTeDSL._get_dsl().envar.arch
        check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
        check_value_in(
            arch,
            [
                "sm_90",
                "sm_90a",
                "sm_100a",
                "sm_100f",
            ],
            "arch",
        )

        mbar_llvm_ptr = nvvm.mapa_shared_cluster(
            mbar_llvm_ptr.type,

@@ -225,6 +225,25 @@ def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> No
        barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
    )


@dsl_user_op
def barrier_arrive(
    *, barrier_id=None, number_of_threads=None, loc=None, ip=None
) -> None:
    if barrier_id is not None:
        barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip)

    if number_of_threads is None:
        raise ValueError(
            "barrier_arrive requires number_of_threads to arrive on the barrier",
        )
    number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip)

    nvvm.barrier_arrive(
        barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
    )


@dsl_user_op
def sync_threads(*, loc=None, ip=None) -> None:
    """
@@ -545,3 +564,20 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


# TODO: add `fastmath` flag for this op
@dsl_user_op
def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
    LOG2_E = 1.4426950408889634
    return exp2(a * LOG2_E, loc=loc, ip=ip)


# TODO: add `fastmath` flag for this op
@dsl_user_op
def exp_packed_f32x2(
    a: Tuple[Float32, Float32], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    LOG2_E = Float32(1.4426950408889634)
    b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip)
    return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)

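Both helpers lean on the identity exp(a) = 2**(a * log2(e)), so a single hardware exp2 instruction implements exp; LOG2_E above is log2(e). A quick sanity check of the math:

    import math

    a = 1.5
    assert abs(2 ** (a * 1.4426950408889634) - math.exp(a)) < 1e-12
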
File diff suppressed because it is too large
@@ -26,7 +26,7 @@ __all__ = [
    #
    # helpers.py
    #
    "make_tma_tile_atom",
    "make_tiled_tma_atom",
    "tma_partition",
    "create_tma_multicast_mask",
    "prefetch_descriptor",

@@ -127,7 +127,12 @@ class CopyBulkTensorTileG2SOp(CopyOp):

    cta_group: CtaGroup = CtaGroup.ONE

    admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        if not isinstance(self.cta_group, CtaGroup):
@@ -159,7 +164,7 @@ class CopyBulkTensorTileG2SOp(CopyOp):
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileG2SNonExecTrait":
        raise NotImplementedError(
            "Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
@@ -224,7 +229,12 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp):

    cta_group: CtaGroup = CtaGroup.ONE

    admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self):
        if not isinstance(self.cta_group, CtaGroup):
@@ -256,7 +266,7 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp):
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait":
        raise NotImplementedError(
            "Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
@@ -326,7 +336,12 @@ class CopyBulkTensorTileS2GOp(CopyOp):
    This Operation uses TMA in the ``.tile`` mode.
    """

    admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self):
        # Arch verification
@@ -345,7 +360,7 @@ class CopyBulkTensorTileS2GOp(CopyOp):
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileS2GTrait":
        raise NotImplementedError(
            "Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )


@@ -29,14 +29,14 @@ from .copy import (


@dsl_user_op
def make_tma_tile_atom(
def make_tiled_tma_atom(
    op: Union[
        CopyBulkTensorTileG2SOp,
        CopyBulkTensorTileG2SMulticastOp,
        CopyBulkTensorTileS2GOp,
    ],
    gmem_tensor: Tensor,
    smem_layout: Layout,
    smem_layout: Union[Layout, core.ComposedLayout],
    cta_tiler: Tiler,
    num_multicast: int = 1,
    *,
@@ -45,7 +45,7 @@ def make_tma_tile_atom(
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from and SMEM
    Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from an SMEM
    buffer with the given Layout.

    Given
@@ -71,7 +71,7 @@ def make_tma_tile_atom(
    :param gmem_tensor: The GMEM tensor involved in the Copy
    :type gmem_tensor: Tensor
    :param smem_layout: The SMEM layout to construct the Copy Atom for
    :type smem_layout: Layout
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param cta_tiler: The CTA Tiler to use
    :type cta_tiler: Tiler
    :param num_multicast: The multicast factor
@@ -94,6 +94,12 @@ def make_tma_tile_atom(
        ip=ip,
    )

    # Wrap smem_layout in a composed layout to make it a TMA-friendly layout
    if isinstance(smem_layout, Layout):
        smem_layout = core.make_composed_layout(
            core.make_swizzle(0, 4, 3), 0, smem_layout
        )

    if isinstance(op, CopyBulkTensorTileG2SOp):
        if num_multicast != 1:
            raise ValueError(

@@ -34,7 +34,7 @@ from .cpasync.copy import (


@dsl_user_op
def make_tma_tile_atom_A(
def make_tiled_tma_atom_A(
    op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
    gmem_tensor: Tensor,
    smem_layout: Layout,
@@ -46,6 +46,51 @@ def make_tma_tile_atom_A(
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy atom mapping to ``.tile`` mode for the ``cp.async.bulk.tensor`` PTX operation,
    accounting for the MK projections of the TiledMMA for A tensor loads.

    Given

    - a GMEM tensor
    - a SMEM layout
    - a MMA Tiler
    - a TiledMma
    - a Cluster-level shape

    this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
    "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided
    layout and consistent with the provided Tiler & tiled_mma (considering the M-mode & K-mode).
    The Cluster-level shape is used to determine the multicast factor across the N-mode for A tensor loads.

    This function returns two results:

    1. the Copy Atom
    2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates
       that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the
       associated layout can output coordinates. Otherwise, TMA tensors can be partitioned
       similarly to any other CuTe tensors using the algebra.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
    :param gmem_tensor: The GMEM tensor to be loaded by this copy atom
    :type gmem_tensor: Tensor
    :param smem_layout: Shared memory layout to load the tensor into (PDSL)
    :type smem_layout: Layout
    :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
    :type mma_tiler_mnk: Shape
    :param tiled_mma: The TiledMMA that will consume the load as operands
    :type tiled_mma: core.TiledMma
    :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
    :type cluster_shape_vmnk: Shape
    :param internal_type: An optional parameter for the internal data type to use when the element
                          type does not match the copy type
    :type internal_type: Type[Numeric]
    :return: A copy atom for this operation and the associated TMA coord tensor
    :rtype: Tuple[core.CopyAtom, Tensor]

    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
@@ -54,7 +99,7 @@ def make_tma_tile_atom_A(
        op,
        [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
        "op",
        "make_tma_tile_atom_A",
        "make_tiled_tma_atom_A",
    )

    ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
@@ -94,7 +139,7 @@ def make_tma_tile_atom_A(


@dsl_user_op
def make_tma_tile_atom_B(
def make_tiled_tma_atom_B(
    op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
    gmem_tensor: Tensor,
    smem_layout: Layout,
@@ -106,6 +151,51 @@ def make_tma_tile_atom_B(
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy atom mapping to ``.tile`` mode for the ``cp.async.bulk.tensor`` PTX operation,
    accounting for the NK projections of the TiledMMA for B tensor loads.

    Given

    - a GMEM tensor
    - a SMEM layout
    - a MMA Tiler
    - a TiledMma
    - a Cluster-level shape

    this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
    "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided
    layout and consistent with the provided Tiler & tiled_mma (considering the N-mode & K-mode).
    The Cluster-level shape is used to determine the multicast factor across the M-mode for B tensor loads.

    This function returns two results:

    1. the Copy Atom
    2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates
       that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the
       associated layout can output coordinates. Otherwise, TMA tensors can be partitioned
       similarly to any other CuTe tensors using the algebra.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
    :param gmem_tensor: The GMEM tensor to be loaded by this copy atom
    :type gmem_tensor: Tensor
    :param smem_layout: Shared memory layout to load the tensor into (PDSL)
    :type smem_layout: Layout
    :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
    :type mma_tiler_mnk: Shape
    :param tiled_mma: The TiledMMA that will consume the load as operands
    :type tiled_mma: core.TiledMma
    :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
    :type cluster_shape_vmnk: Shape
    :param internal_type: An optional parameter for the internal data type to use when the element
                          type does not match the copy type
    :type internal_type: Type[Numeric]
    :return: A Copy Atom for this Operation and the associated TMA tensor
    :rtype: Tuple[core.CopyAtom, Tensor]

    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
@@ -114,7 +204,7 @@ def make_tma_tile_atom_B(
        op,
        [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
        "op",
        "make_tma_tile_atom_B",
        "make_tiled_tma_atom_B",
    )

    ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
@@ -154,6 +244,6 @@ def make_tma_tile_atom_B(


|
||||
"make_tma_tile_atom_A",
|
||||
"make_tma_tile_atom_B",
|
||||
"make_tiled_tma_atom_A",
|
||||
"make_tiled_tma_atom_B",
|
||||
]
|
||||
|
||||
@ -98,7 +98,10 @@ class _LdBase(CopyOp):
|
||||
repeat: Repetition = Repetition.x1
|
||||
pack: Pack = Pack.NONE
|
||||
|
||||
admissible_archs = ["sm_100a"]
|
||||
admissible_archs = [
|
||||
"sm_100a",
|
||||
"sm_100f",
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
@ -284,7 +287,10 @@ class _StBase(CopyOp):
|
||||
repeat: Repetition
|
||||
unpack: Unpack = Unpack.NONE
|
||||
|
||||
admissible_archs = ["sm_100a"]
|
||||
admissible_archs = [
|
||||
"sm_100a",
|
||||
"sm_100f",
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
|
||||
@ -136,7 +136,10 @@ class MmaOp(MmaOp):
|
||||
a_major_mode: OperandMajorMode
|
||||
b_major_mode: OperandMajorMode
|
||||
|
||||
admissible_archs = ["sm_100a"]
|
||||
admissible_archs = [
|
||||
"sm_100a",
|
||||
"sm_100f",
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
|
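These three hunks all widen ``admissible_archs`` from sm_100a-only to also accept sm_100f. A sketch of the gating pattern the ``__post_init__`` bodies share, assuming the ``detect_gpu_arch`` helper seen in the imports elsewhere in this release; the real ops likely raise ``OpError`` with richer context:

.. code-block:: python

    from dataclasses import dataclass

    @dataclass
    class _ArchCheckedOp:
        # Sketch only: stands in for the _LdBase / _StBase / MmaOp arch gating.
        admissible_archs = ["sm_100a", "sm_100f"]

        def __post_init__(self) -> None:
            arch = detect_gpu_arch()  # assumed helper, not defined in this diff
            if arch not in self.admissible_archs:
                raise RuntimeError(
                    f"op requires one of {self.admissible_archs}, got {arch}"
                )
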
@ -339,10 +339,10 @@ class MmaF8Op(MmaOp):
                "expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
            )
        # Accumulator data type verification
        if self.acc_dtype != Float32:
        if self.acc_dtype not in [Float16, Float32]:
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be Float32",
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        # Verify the instruction shape
        instruction_k = 32

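The MmaF8Op hunk relaxes the accumulator check so FP8 MMAs may accumulate in Float16 as well as Float32. A standalone sketch of the relaxed verification, mirroring the hunk above (``Float16``/``Float32`` are the cutlass numeric types referenced in this diff):

.. code-block:: python

    def _verify_acc_dtype(acc_dtype):
        # Accepts the two accumulator types now admissible for FP8 MMAs.
        if acc_dtype not in [Float16, Float32]:
            raise TypeError(
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32"
            )
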
@ -20,6 +20,7 @@ from typing import Union
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir

from cutlass.base_dsl.dsl import is_dynamic_expression
from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry

# Local modules imports
@ -45,7 +46,8 @@ from .typing import (
    BFloat16,
    Float8E5M2,
)
from .core import find, _Tensor as CoreTensor
from . import core
from .core import _Tensor as CoreTensor


class _Pointer(Pointer):
@ -131,6 +133,9 @@ class _Pointer(Pointer):
    def memspace(self):
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        if expected_py_type is Pointer:
            return True
@ -361,7 +366,7 @@ class _Tensor(Tensor):
        * If nested leading dimensions are found, returns a tuple of indices
        * If no leading dimension is found, returns None
        """
        return find(1, self.stride, exclude_when=(1, self.shape))
        return core.leading_dim(self.shape, self.stride)

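The new implementation delegates to ``core.leading_dim``. As a rough pure-Python sketch of the documented contract (index of the stride-1 mode, a tuple of indices for nested modes, ``None`` when absent; the ``exclude_when=(1, self.shape)`` behavior of the old ``find`` call is mirrored by skipping size-1 modes) — an illustration of the contract, not the library's implementation:

.. code-block:: python

    def leading_dim_sketch(shape, stride):
        # Find the mode whose stride is 1, skipping size-1 modes and
        # recursing into nested (tuple) modes.
        for i, (sh, st) in enumerate(zip(shape, stride)):
            if isinstance(st, tuple):
                inner = leading_dim_sketch(sh, st)
                if inner is not None:
                    return (i,) + (inner if isinstance(inner, tuple) else (inner,))
            elif st == 1 and sh != 1:
                return i
        return None
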
    def fill(self, value: Numeric):
        raise TypeError("fill function is not supported in runtime")
@ -479,12 +484,8 @@ class TensorAdapter:
    Convert a DLPack protocol supported tensor/array to a cute tensor.
    """

    # Need to reference these capsules to avoid them being garbage collected
    tensor_capsules = []

    def __init__(self, arg):
        self._arg = from_dlpack(arg).mark_layout_dynamic()
        self.tensor_capsules.append(self._arg)

    def __new_from_mlir_values__(self, values):
        return self._arg.__new_from_mlir_values__(values)

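TensorAdapter wraps any DLPack-capable array so it can be passed straight to a JIT'd kernel; registration through JitArgAdapterRegistry (imported above) is the presumed wiring. A small sketch, assuming a NumPy new enough to export DLPack:

.. code-block:: python

    import numpy as np

    a = np.zeros((128, 64), dtype=np.float16)
    adapted = TensorAdapter(a)  # holds from_dlpack(a).mark_layout_dynamic()
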
@ -9,29 +9,26 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

import random
import numpy as np
import functools
import hashlib

from cutlass.cutlass_dsl import (
    const,
    T,
    CuTeDSL,
    BaseDSL,
    t,
    Constexpr,
    detect_gpu_arch,
)

import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.ir as ir
from cutlass._mlir.dialects import nvvm, cf, vector, builtin

from cutlass.cute import core
from cutlass.cute import nvgpu
from typing import Type
import inspect
import logging
import os
from enum import Enum
from inspect import isclass
from itertools import product
from time import time
from typing import Any, Callable, Dict, List, Optional, Type, Union

import cuda.bindings.driver as cuda_driver
import cuda.bindings.runtime as cuda_runtime
import numpy as np

import cutlass._mlir.ir as ir
import cutlass.base_dsl.jit_executor
import cutlass.cute as cute
from cutlass._mlir.dialects import builtin, cf, nvvm, vector
from cutlass.cute import core, nvgpu
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t


def assert_(cond, msg=None):
@ -248,9 +245,10 @@ def sample_pytest(rand_cfg=None):
    import functools
    import os
    import random
    import pytest
    import sys

    import pytest

    seed, sample_ratio = rand_cfg
    random.seed(seed)

@ -270,3 +268,311 @@ def sample_pytest(rand_cfg=None):
        return wrapper

    return decorator


#########################################
# Benchmarking utilities
#########################################


class JitArguments:
    """
    A type to hold both args and kwargs for passing to a kernel while benchmarking.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def _cuda_success(
    err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str
):
    """
    Helper function to check CUDA API errors.
    """
    if isinstance(err, tuple):
        _cuda_success(err[0], message)
    elif isinstance(err, cuda_runtime.cudaError_t):
        error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8")
        if err != cuda_runtime.cudaError_t.cudaSuccess:
            raise RuntimeError(f"{message} : {error_message}")
    elif isinstance(err, cuda_driver.CUresult):
        if err != cuda_driver.CUresult.CUDA_SUCCESS:
            error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8")
            raise RuntimeError(f"{message} : {error_message}")
    else:
        raise TypeError(
            f"{err} is an unexpected type : it should be a cudaError_t or CUresult"
        )

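_cuda_success normalizes both error-code styles of the cuda-python bindings. A short usage sketch with standard cuda.bindings calls; the tuple-shaped returns are unwrapped by the ``isinstance(err, tuple)`` branch:

.. code-block:: python

    err = cuda_runtime.cudaSetDevice(0)  # returns an (error,) tuple
    _cuda_success(err, "Error selecting device")

    err, free_bytes, total_bytes = cuda_runtime.cudaMemGetInfo()
    _cuda_success(err, "Error querying device memory")
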
|
||||
def _does_kernel_use_stream(
|
||||
kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
This function checks if the kernel uses the provided non-default stream.
|
||||
It does this by capturing the stream and then checking if any kernels were launched.
|
||||
:param kernel: The kernel to check
|
||||
:type kernel: Callable
|
||||
:param stream: The stream to check
|
||||
:type stream: cuda_driver.CUstream
|
||||
:return: True if the kernel uses the stream, False otherwise
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
assert int(stream) != int(
|
||||
cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
|
||||
), "Stream must be a non-default stream"
|
||||
|
||||
err = cuda_runtime.cudaStreamBeginCapture(
|
||||
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
|
||||
)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
|
||||
kernel(*args, **kwargs)
|
||||
|
||||
err, graph = cuda_runtime.cudaStreamEndCapture(stream)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
|
||||
# Get number of nodes in warmup graph to check it matches what is expected
|
||||
err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph)
|
||||
_cuda_success(err, "Error on querying graph")
|
||||
return num_nodes > 0
|
||||
|
||||
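A sketch of how the benchmark path below uses this probe. The stream creation is the standard driver call; ``my_kernel``, ``a``, and ``b`` are placeholders for whatever callable and arguments are being measured:

.. code-block:: python

    err, probe_stream = cuda_driver.cuStreamCreate(0)
    _cuda_success(err, "Error creating stream")

    if not _does_kernel_use_stream(my_kernel, probe_stream, a, b, stream=probe_stream):
        raise ValueError("Kernel does not launch work on the provided stream")
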
|
||||
def benchmark(
|
||||
callable: Callable,
|
||||
*,
|
||||
warmup_iterations: int = 10,
|
||||
profiling_iterations: int = 100,
|
||||
stream: Optional[cuda_driver.CUstream] = None,
|
||||
kernel_arguments: Optional[JitArguments] = None,
|
||||
workspace_generator: Optional[Callable[[], JitArguments]] = None,
|
||||
workspace_count: int = 1,
|
||||
use_cuda_graphs: bool = False,
|
||||
) -> float:
|
||||
"""Benchmarks a callable function with the specified parameters.
|
||||
|
||||
For example,
|
||||
.. code-block:: python
|
||||
|
||||
from cutlass.cute.testing import benchmark
|
||||
|
||||
@cute.jit
|
||||
def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream):
|
||||
# contents of the function
|
||||
pass
|
||||
|
||||
time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream)
|
||||
warmup_iterations=10, profiling_iterations=100
|
||||
stream=stream)
|
||||
|
||||
To prevent skewing results by repeately accessing the L2 cache, use the workspace_count and workspace_generator
|
||||
parameters to cycle through a number of different workspaces.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from cutlass.cute.testing import benchmark
|
||||
|
||||
@cute.jit
|
||||
def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
|
||||
# contents of the function
|
||||
pass
|
||||
|
||||
def workspace_generator():
|
||||
# create a, b, and c
|
||||
return JitArguments(a, b, c)
|
||||
|
||||
time_us = benchmark(user_function,
|
||||
workspace_generator=workspace_generator,
|
||||
workspace_count=10,
|
||||
warmup_iterations=10000,
|
||||
profiling_iterations=1000)
|
||||
|
||||
To benchmark you may always configure the function being profiled (callable), the warmup iterations, and
|
||||
the number of profiling iterations.
|
||||
|
||||
Whenever the kernel being benchmarked runs in a non-default stream, the stream must be provided through the stream parameter.
|
||||
|
||||
To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
|
||||
When using CUDA graphs, the kernel must be launched in a non-default stream.
|
||||
|
||||
:param callable: The function to benchmark
|
||||
:type callable: Callable
|
||||
:param warmup_iterations: Number of warmup iterations, defaults to 10
|
||||
:type warmup_iterations: int, optional
|
||||
:param profiling_iterations: Number of benchmark iterations, defaults to 100
|
||||
:type profiling_iterations: int, optional
|
||||
:param stream: Stream kernel is launched in, defaults to CUDA stream default
|
||||
:type stream: CUstream, None
|
||||
:param kernel_arguments: Kernel arguments to launch callable with, defaults to None
|
||||
:type kernel_arguments: JitArguments, None
|
||||
:param workspace_generator: Function that returns kernel arguments, defaults to None
|
||||
:type workspace_generator: Callable
|
||||
:param workspace_count: Number of workspaces (arguments) to loop through, looping through enough workspaces will keep the L2 cache cold
|
||||
:type workspace_count: int, optional
|
||||
:param use_cuda_graphs: Whether to use cuda graphs, defaults to False
|
||||
:type use_cuda_graphs: bool, optional
|
||||
|
||||
:return: The benchmark time in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
if stream is None:
|
||||
stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT)
|
||||
|
||||
if workspace_count < 1:
|
||||
raise ValueError("workspace_count must be at least 1")
|
||||
|
||||
time_us = float("nan")
|
||||
if workspace_generator == None:
|
||||
# If no workspace generator is provided, we need a single workspace
|
||||
if workspace_count != 1:
|
||||
raise ValueError("Need a single workspace if not providing a generator")
|
||||
|
||||
# If no workspace generator is provided, we need a kernel_argument
|
||||
if kernel_arguments == None:
|
||||
raise ValueError(
|
||||
"Please pass a kernel argument if not providing a generator"
|
||||
)
|
||||
workspace_generator = lambda: kernel_arguments
|
||||
|
||||
workspaces = [workspace_generator() for _ in range(workspace_count)]
|
||||
|
||||
for workspace in workspaces:
|
||||
if type(workspace) != JitArguments:
|
||||
raise TypeError(
|
||||
"workspace_generator and/or kernel_arguments should use JitArguments type"
|
||||
)
|
||||
|
||||
def _loop_and_call_kernel(iterations: int, workspace_index: int = 0):
|
||||
for _ in range(iterations):
|
||||
current_workspace = workspaces[workspace_index]
|
||||
callable(*current_workspace.args, **current_workspace.kwargs)
|
||||
workspace_index = (workspace_index + 1) % workspace_count
|
||||
return workspace_index
|
||||
|
||||
# Create CUDA events for timing
|
||||
err, start_event = cuda_driver.cuEventCreate(
|
||||
cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
|
||||
)
|
||||
_cuda_success(err, "Error on creating event")
|
||||
err, end_event = cuda_driver.cuEventCreate(
|
||||
cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
|
||||
)
|
||||
_cuda_success(err, "Error on creating event")
|
||||
|
||||
elapsed_time = float("nan")
|
||||
|
||||
if use_cuda_graphs:
|
||||
# Check if the callable is a JitExecutor
|
||||
if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor):
|
||||
raise TypeError("Function must be precompiled to be used with CUDA Graphs")
|
||||
|
||||
# Check if the stream is a non-default stream
|
||||
if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
|
||||
raise ValueError(
|
||||
"Measuring with CUDA Graphs requires executing in a non-default stream"
|
||||
)
|
||||
|
||||
workspace_index = 0
|
||||
|
||||
# Capture warmup graph
|
||||
err = cuda_runtime.cudaStreamBeginCapture(
|
||||
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
|
||||
)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
|
||||
workspace_index = _loop_and_call_kernel(warmup_iterations)
|
||||
err, gwarm = cuda_runtime.cudaStreamEndCapture(stream)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
|
||||
# Get number of nodes in warmup graph to check it matches what is expected
|
||||
err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm)
|
||||
_cuda_success(err, "Error on querying graph")
|
||||
# Assertion is >= since we may launch multiple kernels in one host function
|
||||
if num_nodes < warmup_iterations:
|
||||
raise ValueError(
|
||||
f"CUDA stream passed to benchmark does not match the stream the kernel was launched in"
|
||||
)
|
||||
|
||||
# Capture profiling graph
|
||||
err = cuda_runtime.cudaStreamBeginCapture(
|
||||
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
|
||||
)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
_loop_and_call_kernel(profiling_iterations, workspace_index)
|
||||
err, gprofile = cuda_runtime.cudaStreamEndCapture(stream)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
|
||||
# Instantiate graphs
|
||||
err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0)
|
||||
_cuda_success(err, "Error on graph instantiation")
|
||||
err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0)
|
||||
_cuda_success(err, "Error on graph instantiation")
|
||||
|
||||
# Launch warmup graph
|
||||
err = cuda_runtime.cudaGraphLaunch(gwarm, stream)
|
||||
_cuda_success(err, "Error on graph launch")
|
||||
|
||||
# Record start time
|
||||
err = cuda_driver.cuEventRecord(start_event, stream)
|
||||
_cuda_success(err, "Error on recording event")
|
||||
|
||||
# Launch profiling graph
|
||||
err = cuda_runtime.cudaGraphLaunch(gprofile, stream)
|
||||
_cuda_success(err, "Error on graph launch")
|
||||
|
||||
# Record end time
|
||||
err = cuda_driver.cuEventRecord(end_event, stream)
|
||||
_cuda_success(err, "Error on recording event")
|
||||
err = cuda_driver.cuEventSynchronize(end_event)
|
||||
_cuda_success(err, "Error on synchronizing event")
|
||||
|
||||
# Get elapsed time
|
||||
err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
|
||||
_cuda_success(err, "Error on querying event")
|
||||
|
||||
# Destroy graphs
|
||||
err = cuda_runtime.cudaGraphExecDestroy(gwarm)
|
||||
_cuda_success(err, "Error on destroying graph")
|
||||
err = cuda_runtime.cudaGraphExecDestroy(gprofile)
|
||||
_cuda_success(err, "Error on destroying graph")
|
||||
|
||||
else:
|
||||
|
||||
if int(stream) != int(
|
||||
cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
|
||||
) and not _does_kernel_use_stream(
|
||||
callable, stream, *workspaces[0].args, **workspaces[0].kwargs
|
||||
):
|
||||
raise ValueError(
|
||||
"CUDA stream passed to benchmark does not match the stream the kernel was launched in"
|
||||
)
|
||||
|
||||
# Not using graphs
|
||||
# Warmup
|
||||
workspace_index = _loop_and_call_kernel(warmup_iterations)
|
||||
# Record start event
|
||||
err = cuda_driver.cuEventRecord(start_event, stream)
|
||||
_cuda_success(err, "Error on recording event")
|
||||
_loop_and_call_kernel(profiling_iterations, workspace_index)
|
||||
# Record end event
|
||||
err = cuda_driver.cuEventRecord(end_event, stream)
|
||||
_cuda_success(err, "Error on recording event")
|
||||
# Synchronize end event
|
||||
err = cuda_driver.cuEventSynchronize(end_event)
|
||||
_cuda_success(err, "Error on synchronizing event")
|
||||
err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
|
||||
_cuda_success(err, "Error on querying event")
|
||||
|
||||
# Destroy events
|
||||
err = cuda_driver.cuEventDestroy(start_event)
|
||||
_cuda_success(err, "Error on destroying event")
|
||||
err = cuda_driver.cuEventDestroy(end_event)
|
||||
_cuda_success(err, "Error on destroying event")
|
||||
|
||||
return elapsed_time / profiling_iterations * 1e3
|
||||
|
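One rough way to pick ``workspace_count`` so the cycled working set genuinely keeps L2 cold: size it against the device's L2 cache. The attribute query is a standard driver call; the per-workspace byte count is an assumption that depends on your tensors:

.. code-block:: python

    err, l2_bytes = cuda_driver.cuDeviceGetAttribute(
        cuda_driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, 0
    )
    _cuda_success(err, "Error querying L2 size")

    workspace_bytes = 3 * 1024 * 1024  # assumption: bytes touched per workspace
    workspace_count = max(1, -(-2 * l2_bytes // workspace_bytes))  # ~2x L2 of data
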
@ -68,6 +68,8 @@ class Pointer(ABC):
    @property
    def dtype(self) -> Type[Numeric]: ...

    def align(self, min_align: int) -> "Pointer": ...

    def __get_mlir_types__(self) -> List[ir.Type]: ...

    def __extract_mlir_values__(self) -> List[ir.Value]: ...

62  python/CuTeDSL/cutlass/pipeline/__init__.py  Normal file
@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from .helpers import (
    Agent,
    CooperativeGroup,
    PipelineOp,
    SyncObject,
    MbarrierArray,
    NamedBarrier,
    TmaStoreFence,
    PipelineUserType,
    PipelineState,
    make_pipeline_state,
    pipeline_init_wait,
    arrive,
    arrive_unaligned,
    wait,
    wait_unaligned,
    arrive_and_wait,
    sync,
)

from .sm90 import (
    PipelineAsync,
    PipelineTmaAsync,
    PipelineTmaMultiConsumersAsync,
    PipelineTmaStore,
)

from .sm100 import (
    PipelineTmaUmma,
    PipelineAsyncUmma,
    PipelineUmmaAsync,
)

__all__ = [
    "Agent",
    "CooperativeGroup",
    "PipelineOp",
    "SyncObject",
    "MbarrierArray",
    "NamedBarrier",
    "TmaStoreFence",
    "PipelineUserType",
    "PipelineState",
    "PipelineAsync",
    "PipelineTmaAsync",
    "PipelineTmaUmma",
    "PipelineTmaMultiConsumersAsync",
    "PipelineAsyncUmma",
    "PipelineUmmaAsync",
    "PipelineTmaStore",
]

645  python/CuTeDSL/cutlass/pipeline/helpers.py  Normal file
@ -0,0 +1,645 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings

import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, Int32, Int64, if_generate
from cutlass._mlir.dialects import llvm
import cutlass._mlir.dialects.cute as _cute_ir


##############################################################################
# Agent class
##############################################################################


class Agent(enum.Enum):
    """
    Agent indicates what is participating in the pipeline synchronization.
    """

    # Arbitrary grouping of N threads
    Thread = enum.auto()
    # Same as Thread, but includes all threads in the block
    ThreadBlock = enum.auto()
    # Same as Thread, but includes all threads in the cluster
    ThreadBlockCluster = enum.auto()


class CooperativeGroup:
    """
    CooperativeGroup contains size and alignment restrictions for an Agent.
    """

    def __init__(self, agent: Agent, size: int = 1, alignment: int = 1):
        if agent is Agent.Thread:
            assert size > 0
            if size == 32:
                assert (
                    size == alignment
                ), "Error: Alignment does not match number of threads in a warp."
            elif size == 128:
                assert (
                    size == alignment
                ), "Error: Alignment does not match number of threads in a warpgroup."
        elif agent is Agent.ThreadBlock:
            raise NotImplementedError("Error: Not yet supported.")
        elif agent is Agent.ThreadBlockCluster:
            raise NotImplementedError("Error: Not yet supported.")
        else:
            # Should never reach this state
            size = 0

        if size <= 0:
            raise ValueError(
                "Error: The number of threads in a CooperativeGroup must be more than 0."
            )

        # Size indicates how many threads are participating in this CooperativeGroup
        self.size = size
        # Agent indicates the type of thread group
        self.agent = agent

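A small construction sketch for the producer/consumer groups used by the pipelines below; one warp of 32 threads, with alignment equal to size as the checks above require:

.. code-block:: python

    producer_group = CooperativeGroup(Agent.Thread, size=32, alignment=32)
    consumer_group = CooperativeGroup(Agent.Thread, size=32, alignment=32)
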
class PipelineOp(enum.Enum):
    """
    PipelineOp assigns an operation to an agent corresponding to a specific hardware feature.
    """

    # async-threads
    AsyncThread = enum.auto()
    # Blackwell (SM100a) MMA instruction
    TCGen05Mma = enum.auto()
    # Tensor Memory Accelerator load
    TmaLoad = enum.auto()
    # TMA Store consuming smem produced by AsyncThread
    TmaStore = enum.auto()
    # Composite of multiple PipelineOps
    Composite = enum.auto()


def _get_pipeline_op(type_str):
    return PipelineOp(type_str)


##############################################################################
# SyncObject class
##############################################################################


class SyncObject(ABC):
    """Abstract base class for hardware synchronization primitives.

    This class defines the interface for different types of hardware synchronization
    mechanisms including shared memory barriers, named barriers, and fences.
    """

    @abstractmethod
    def arrive(self) -> None:
        pass

    @abstractmethod
    def wait(self) -> None:
        pass

    @abstractmethod
    def arrive_and_wait(self) -> None:
        pass

    @abstractmethod
    def arrive_and_drop(self) -> None:
        pass

    @abstractmethod
    def get_barrier(self) -> Union[cute.Pointer, int, None]:
        pass

    @abstractmethod
    def max(self) -> Union[int, None]:
        pass


class MbarrierArray(SyncObject):
    """
    MbarrierArray implements an abstraction for an array of smem barriers.
    """

    def __init__(
        self,
        barrier_storage: cute.Pointer,
        num_stages: int,
        agent: tuple[PipelineOp, CooperativeGroup],
        tx_count: int = 0,
    ) -> None:
        self.barrier_storage = barrier_storage
        self.tx_count = tx_count
        self.num_stages = num_stages
        self.op_type, self.cg = agent
        self.arrive_count = self.cg.size

        if self.num_stages <= 0:
            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
        if self.arrive_count <= 0:
            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
            raise ValueError(
                "Error: Mbarrier tx count must not be less than 0 for TMA ops."
            )

        # Store mbarrier base pointer
        self.mbarrier_base = self.barrier_storage

        # Mbarrier initialization in constructor
        self.mbarrier_init()

    def recast_to_new_op_type(self, new_op_type: PipelineOp) -> "MbarrierArray":
        """
        Creates a copy of MbarrierArray with a different op_type without re-initializing barriers.
        """
        # Create new instance without initialization
        new_mbarrier_array = object.__new__(MbarrierArray)

        # Copy all attributes directly
        new_mbarrier_array.barrier_storage = self.barrier_storage
        new_mbarrier_array.op_type = new_op_type
        new_mbarrier_array.cg = self.cg
        new_mbarrier_array.num_stages = self.num_stages
        new_mbarrier_array.tx_count = self.tx_count
        new_mbarrier_array.arrive_count = self.arrive_count
        new_mbarrier_array.mbarrier_base = self.mbarrier_base
        return new_mbarrier_array

    # Mbarrier initialization
    def mbarrier_init(self) -> None:
        """
        Initializes an array of mbarriers using warp 0.
        """

        def then_body():
            for index in range(self.num_stages):
                cute.arch.mbarrier_init(self.get_barrier(index), self.arrive_count)

        warp_idx = cute.arch.warp_idx()
        warp_idx = cute.arch.make_warp_uniform(warp_idx)

        if_generate(warp_idx == 0, then_body)

    def arrive(
        self,
        index: int,
        dst: int,
        cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
    ) -> None:
        """Select the arrive flavor corresponding to this MbarrierArray's PipelineOp.

        :param index: Index of the mbarrier in the array to arrive on
        :type index: int
        :param dst: Destination parameter for selective arrival, which can be either a mask or a destination cta rank.
            When None, both ``TCGen05Mma`` and ``AsyncThread`` will arrive on their local mbarrier.

            - For ``TCGen05Mma``, ``dst`` serves as a multicast mask (e.g., 0b1011 allows the arrive signal to be multicast to CTAs
              in the cluster with rank = 0, 1, and 3).
            - For ``AsyncThread``, ``dst`` serves as a destination cta rank (e.g., 3 means threads will arrive on
              the mbarrier with rank = 3 in the cluster).
        :type dst: int | None
        :param cta_group: CTA group for ``TCGen05Mma``, defaults to None for other op types
        :type cta_group: ``cute.nvgpu.tcgen05.CtaGroup``, optional
        """
        if self.op_type is PipelineOp.AsyncThread:
            self.arrive_mbarrier(index, dst)
        elif self.op_type is PipelineOp.TCGen05Mma:
            assert (
                cta_group is not None
            ), "Error: CTA group must be provided for TCGen05Mma."
            self.arrive_tcgen05mma(index, dst, cta_group)
        elif self.op_type in [PipelineOp.TmaLoad]:
            self.arrive_and_expect_tx(index, self.tx_count)
        else:
            assert (
                False
            ), f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}."

    def arrive_mbarrier(self, index: int, dst_rank: Optional[int] = None) -> None:
        if dst_rank is None:
            cute.arch.mbarrier_arrive(self.get_barrier(index))
        else:
            cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank)

    def arrive_tcgen05mma(
        self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup
    ) -> None:
        if mask is None:
            with cute.arch.elect_one():
                cute.nvgpu.tcgen05.commit(self.get_barrier(index))
        else:
            with cute.arch.elect_one():
                cute.nvgpu.tcgen05.commit(self.get_barrier(index), mask, cta_group)

    def arrive_and_expect_tx(self, index: int, tx_count: int) -> None:
        with cute.arch.elect_one():
            cute.arch.mbarrier_arrive_and_expect_tx(self.get_barrier(index), tx_count)

    def try_wait(self, index: int, phase: int) -> Boolean:
        return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase)

    def wait(self, index: int, phase: int) -> None:
        cute.arch.mbarrier_wait(self.get_barrier(index), phase)

    def arrive_and_wait(
        self,
        index: int,
        phase: int,
        dst: int,
        cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
    ) -> None:
        self.arrive(index, dst, cta_group)
        self.wait(index, phase)

    def arrive_and_drop(self) -> None:
        raise NotImplementedError("Error: Not yet supported.")

    def get_barrier(self, index: int) -> cute.Pointer:
        return self.mbarrier_base + index

    def max(self) -> int:
        # Transaction barriers have a maximum arrive count of 511 (2^9 - 1).
        # Non-transaction barriers have a maximum arrive count of 1,048,575 (2^20 - 1).
        return 511

    def __extract_mlir_values__(self):
        return [self.barrier_storage]

    def __new_from_mlir_values__(self, values):
        return MbarrierArray(
            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count
        )

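A sketch of constructing an MbarrierArray for TMA loads. The smem allocation call and the byte counts are placeholders; in a real kernel the storage comes from the kernel's smem layout and ``tx_count`` from the TMA copy size:

.. code-block:: python

    # Hypothetical allocation: 4 stages of 8-byte mbarriers, 8B-aligned.
    barrier_ptr = smem_allocator.allocate(4 * 8, byte_alignment=8)
    full_barriers = MbarrierArray(
        barrier_storage=barrier_ptr,
        num_stages=4,
        agent=(PipelineOp.TmaLoad, CooperativeGroup(Agent.Thread)),
        tx_count=tma_bytes_per_stage,  # assumed: bytes written per stage
    )
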
@dataclass(frozen=True)
class NamedBarrier(SyncObject):
    """
    NamedBarrier is an abstraction for named barriers managed by hardware.
    There are 16 named barriers available, with barrier_ids 0-15.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-bar>`__.
    """

    barrier_id: int
    num_threads: int

    def __post_init__(self) -> None:
        if self.barrier_id < 0 or self.barrier_id >= 16:
            raise ValueError("Error: NamedBarrier ID must be between 0 and 15.")
        if self.barrier_id == 0:
            warnings.warn(
                "NamedBarrier ID 0 is used by other driver APIs (e.g. sync_threads()) and should not be used."
            )

    def arrive(self) -> None:
        """
        The aligned flavor of arrive is used when all threads in the CTA will execute the
        same instruction. See the PTX documentation.
        """
        cute.arch.barrier_arrive(
            barrier_id=self.barrier_id, number_of_threads=self.num_threads
        )

    def arrive_unaligned(self) -> None:
        """
        The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA.
        """
        llvm.inline_asm(
            None,
            [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()],
            "barrier.arrive $0, $1;",
            "r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )

    def wait(self) -> None:
        """
        NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait.
        If synchronizing two warps in a producer/consumer pairing, the arrive count would be
        32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer
        or consumer are counted for mbarriers, while all threads participating in the sync
        are counted for NamedBarriers.
        """
        warnings.warn(
            "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
        )
        self.arrive_and_wait()

    def wait_unaligned(self) -> None:
        warnings.warn(
            "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
        )
        llvm.inline_asm(
            None,
            [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()],
            "barrier.sync $0, $1;",
            "r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )

    def arrive_and_wait(self) -> None:
        cute.arch.barrier(
            barrier_id=self.barrier_id, number_of_threads=self.num_threads
        )

    def arrive_and_drop(self) -> None:
        raise NotImplementedError("Error: Not supported.")

    def sync(self) -> None:
        cute.arch.barrier(barrier_id=self.barrier_id)

    def get_barrier(self) -> int:
        return self.barrier_id

    def max(self) -> int:
        # Named barriers have a maximum arrive count of 4095 (2^12 - 1).
        return 4095

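A usage sketch: synchronizing one producer warp with one consumer warp through a single named barrier. Note that, per the ``wait()`` docstring above, the arrive count spans both warps:

.. code-block:: python

    # ID 1 (ID 0 is effectively reserved for CTA-wide syncs); 64 = 32 + 32 threads.
    pc_barrier = NamedBarrier(barrier_id=1, num_threads=64)

    # Producer warp: publish data to smem, then signal.
    pc_barrier.arrive()

    # Consumer warp: blocks until all 64 threads have arrived.
    pc_barrier.arrive_and_wait()
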
class TmaStoreFence(SyncObject):
    """
    TmaStoreFence is used for a multi-stage epilogue buffer.
    """

    def __init__(self, num_stages: int = 0) -> None:
        if num_stages <= 0:
            raise ValueError("Mbarrier stage count must be greater than 0.")

        self.num_stages = num_stages

    def arrive(self) -> None:
        cute.arch.cp_async_bulk_commit_group()

    def wait(self) -> None:
        cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)

    def arrive_and_wait(self) -> None:
        self.arrive()
        self.wait()

    def arrive_and_drop(self) -> None:
        raise NotImplementedError("Error: Not supported.")

    # TmaStoreFence doesn't have mbarriers
    def get_barrier(self) -> None:
        assert (
            False
        ), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier."

    def max(self) -> None:
        raise NotImplementedError("Error: Not supported.")

    def tail(self) -> None:
        cute.arch.cp_async_bulk_wait_group(0, read=True)


##############################################################################
# PipelineState class
##############################################################################


class PipelineUserType(enum.Enum):
    Producer = enum.auto()
    Consumer = enum.auto()


class PipelineState:
    """
    Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
    """

    def __init__(self, stages: int, count, index, phase):
        self._stages = stages
        self._count = count
        self._index = index
        self._phase = phase

    def clone(self) -> "PipelineState":
        return PipelineState(self.stages, self._count, self.index, self.phase)

    @property
    def index(self) -> Int32:
        return self._index

    @property
    def count(self) -> Int32:
        return self._count

    @property
    def stages(self) -> int:
        return self._stages

    @property
    def phase(self) -> Int32:
        return self._phase

    def reset_count(self):
        self._count = Int32(0)

    def advance(self):
        self._index += 1
        self._count += 1

        def then_body(index, phase):
            new_index = Int32(0)
            new_phase = phase ^ 1
            return new_index, new_phase

        def else_body(index, phase):
            return index, phase

        self._index, self._phase = if_generate(
            self._index == self.stages,
            then_body,
            else_body,
            [self.index, self.phase],
            [Int32, Int32],
        )

    def reverse(self):
        self._index -= 1
        self._count -= 1

        def then_body(index, phase):
            new_index = Int32(self.stages - 1)
            new_phase = phase ^ 1
            return new_index, new_phase

        def else_body(index, phase):
            return index, phase

        self._index, self._phase = if_generate(
            self._index == -1,
            then_body,
            else_body,
            [self.index, self.phase],
            [Int32, Int32],
        )

    def __get_mlir_types__(self):
        return [self._count.type, self._index.type, self._phase.type]

    def __extract_mlir_values__(self):
        count = self._count
        index = self._index
        phase = self._phase
        return [count.ir_value(), index.ir_value(), phase.ir_value()]

    # This can be overridden by derived classes
    def __new_from_mlir_values__(self, values):
        return PipelineState(
            self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
        )


def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
    """
    if type is PipelineUserType.Producer:
        return PipelineState(
            stages,
            Int32(0),
            Int32(0),
            Int32(1),
        )
    elif type is PipelineUserType.Consumer:
        return PipelineState(
            stages,
            Int32(0),
            Int32(0),
            Int32(0),
        )
    else:
        assert (
            False
        ), "Error: invalid PipelineUserType specified for make_pipeline_state."

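The ``advance``/``reverse`` logic above flips the phase bit every time the index wraps around the stage count. A host-side illustration of the same circular-buffer arithmetic, with plain Python ints standing in for the DSL's ``Int32``/``if_generate``:

.. code-block:: python

    stages = 4
    index, phase = 0, 1  # a producer starts at index 0 with phase 1
    for _ in range(10):  # ten advances
        index += 1
        if index == stages:  # wrap: reset index, flip phase
            index = 0
            phase ^= 1
    # index is now 2; phase flipped twice (two wraps) and is back to 1
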
##############################################################################
# Helper functions
##############################################################################


def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
    """
    Fences the mbarrier init and syncs the threadblock or cluster.
    """
    cute.arch.mbarrier_init_fence()

    if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
        # If not using clusters, sync the threadblock
        _sync(Agent.ThreadBlock)
    else:
        # If using clusters, sync the cluster
        _sync(Agent.ThreadBlockCluster)


def _sync(group: Agent):
    """
    Syncs all threads within an agent.
    """
    if group is Agent.Thread:
        raise NotImplementedError("Error: Not supported.")
    elif group is Agent.ThreadBlock:
        cute.arch.sync_threads()
    elif group is Agent.ThreadBlockCluster:
        cute.arch.cluster_arrive()
        cute.arch.cluster_wait()
    else:
        assert (
            False
        ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."


def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer:
    """
    Converts a smem pointer of type Int64 to a cute.Pointer with 8B alignment.
    """
    return cute.make_ptr(
        Int64,
        val.ir_value(),
        mem_space=_cute_ir.AddressSpace.smem,
        assumed_align=8,
    )


# NamedBarrier free functions
def arrive(barrier_id: int, num_threads: int):
    """
    The aligned flavor of arrive is used when all threads in the CTA will execute the
    same instruction. See the PTX documentation.
    """
    cute.arch.barrier_arrive(barrier_id=barrier_id, number_of_threads=num_threads)


def arrive_unaligned(barrier_id: int, num_threads: int):
    """
    The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA.
    """
    llvm.inline_asm(
        None,
        [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()],
        "barrier.arrive $0, $1;",
        "r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


def wait(barrier_id: int, num_threads: int):
    """
    NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait.
    If synchronizing two warps in a producer/consumer pairing, the arrive count would be
    32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer
    or consumer are counted for mbarriers, while all threads participating in the sync
    are counted for NamedBarriers.
    """
    warnings.warn(
        "NamedBarrier wait also arrives on the barrier. Routing call to arrive_and_wait()."
    )
    arrive_and_wait(barrier_id, num_threads)


def wait_unaligned(barrier_id: int, num_threads: int):
    warnings.warn(
        "NamedBarrier wait also arrives on the barrier. Routing call to arrive_and_wait()."
    )
    llvm.inline_asm(
        None,
        [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()],
        "barrier.sync $0, $1;",
        "r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


def arrive_and_wait(barrier_id: int, num_threads: int):
    cute.arch.barrier(barrier_id=barrier_id, number_of_threads=num_threads)


def sync(barrier_id: int = 0):
    cute.arch.barrier(barrier_id=barrier_id)

452  python/CuTeDSL/cutlass/pipeline/sm100.py  Normal file
@ -0,0 +1,452 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings

import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, if_generate

from cutlass.pipeline import (
    CooperativeGroup,
    PipelineOp,
    PipelineState,
    pipeline_init_wait,
    PipelineAsync,
)

##############################################################################
# Pipeline classes
##############################################################################


@dataclass(frozen=True)
class PipelineTmaUmma(PipelineAsync):
    """
    PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops).
    """

    is_leader_cta: bool
    cta_group: cute.nvgpu.tcgen05.CtaGroup

    @staticmethod
    def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout):
        """
        Computes a mask for signaling arrivals to multicasting threadblocks.
        """
        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )
        cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)

        tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
            cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
        )
        tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
            cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
        )

        block_in_cluster_coord_vmnk_peer = (
            cta_in_cluster_coord_vmnk[0] ^ 1,
            *cta_in_cluster_coord_vmnk[1:],
        )
        tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
            cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2
        )
        tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
            cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1
        )

        return (
            tma_mcast_mask_a
            | tma_mcast_mask_b
            | tma_mcast_mask_a_peer
            | tma_mcast_mask_b_peer
        )

    @staticmethod
    def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
        """
        Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
        """
        bidx, bidy, _ = cute.arch.block_idx()

        mma_coord_vmnk = (
            bidx % cute.size(cta_layout_vmnk, mode=[0]),
            bidx // cute.size(cta_layout_vmnk, mode=[0]),
            bidy,
            None,
        )
        return mma_coord_vmnk[0] == 0

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.

        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.TCGen05Mma

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer, tx_count
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )

        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
            # No mcast mask if not using clusters
            producer_mask = None
            # All threadblocks are leaders if not using clusters
            is_leader_cta = True
        else:
            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)

        cta_group = (
            cute.nvgpu.tcgen05.CtaGroup.ONE
            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
            else cute.nvgpu.tcgen05.CtaGroup.TWO
        )

        consumer_mask = producer_mask

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaUmma(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            consumer_mask,
            is_leader_cta,
            cta_group,
        )

    def consumer_release(self, state: PipelineState):
        """
        UMMA consumer release of the buffer-empty barrier; cta_group needs to be provided.
        """
        self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group)

    def producer_acquire(
        self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
    ):
        """
        TMA producer acquire conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )
        if_generate(
            self.is_leader_cta,
            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
        )

    def producer_commit(self, state: PipelineState):
        """
        TMA producer commit is a noop since the TMA instruction itself updates the transaction count.
        """
        pass

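A creation/usage sketch for this TMA-producer / UMMA-consumer pipeline. Storage, counts, and states are placeholders, and the producer/consumer method names follow the p.acquire / p.commit / c.wait / c.release transitions tabulated for PipelineAsync further below:

.. code-block:: python

    pipeline = PipelineTmaUmma.create(
        num_stages=4,
        producer_group=CooperativeGroup(Agent.Thread),  # elected TMA thread
        consumer_group=CooperativeGroup(Agent.Thread),  # elected MMA thread
        tx_count=tma_bytes_per_stage,     # assumed: TMA bytes per stage
        barrier_storage=barrier_ptr,      # assumed: smem pointer for mbarriers
        cta_layout_vmnk=cta_layout_vmnk,  # None for 1-CTA kernels
    )

    # Producer (TMA) side:
    pipeline.producer_acquire(producer_state)
    # ... issue the TMA copy against the full barrier for this stage ...

    # Consumer (UMMA) side:
    pipeline.consumer_wait(consumer_state)
    # ... issue MMAs reading the stage ...
    pipeline.consumer_release(consumer_state)
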
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineAsyncUmma(PipelineAsync):
|
||||
"""
|
||||
PipelineAsyncUmma is used for AsyncThread producers and UMMA consumers (e.g. Blackwell input fusion pipelines).
|
||||
"""
|
||||
|
||||
cta_group: cute.nvgpu.tcgen05.CtaGroup
|
||||
|
||||
@staticmethod
|
||||
def _compute_leading_cta_rank(cta_v_size):
|
||||
"""
|
||||
Computes the leading CTA rank.
|
||||
"""
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
cute.arch.block_idx_in_cluster()
|
||||
)
|
||||
return cta_rank_in_cluster // cta_v_size * cta_v_size
|
||||
|
||||
@staticmethod
|
||||
def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
|
||||
"""
|
||||
Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
|
||||
"""
|
||||
bidx, bidy, _ = cute.arch.block_idx()
|
||||
mma_coord_vmnk = (
|
||||
bidx % cute.size(cta_layout_vmnk, mode=[0]),
|
||||
bidx // cute.size(cta_layout_vmnk, mode=[0]),
|
||||
bidy,
|
||||
None,
|
||||
)
|
||||
return mma_coord_vmnk[0] == 0
|
||||
|
||||
@staticmethod
|
||||
def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout):
|
||||
"""
|
||||
Computes a mask for signaling arrivals to multicasting threadblocks.
|
||||
"""
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
cute.arch.block_idx_in_cluster()
|
||||
)
|
||||
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
||||
mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0
|
||||
)
|
||||
block_in_cluster_coord_vmnk_peer = (
|
||||
cta_in_cluster_coord_vmnk[0] ^ 1,
|
||||
*cta_in_cluster_coord_vmnk[1:],
|
||||
)
|
||||
mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0
|
||||
)
|
||||
return mask_self | mask_peer
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
*,
|
||||
num_stages: int,
|
||||
producer_group: CooperativeGroup,
|
||||
consumer_group: CooperativeGroup,
|
||||
barrier_storage: cute.Pointer = None,
|
||||
cta_layout_vmnk: Optional[cute.Layout] = None,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineAsyncUmma.
|
||||
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param cta_layout_vmnk: Layout of the cluster shape
|
||||
:type cta_layout_vmnk: cute.Layout | None
|
||||
"""
|
||||
if not isinstance(barrier_storage, cute.Pointer):
|
||||
raise ValueError(
|
||||
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
|
||||
)
|
||||
|
||||
producer_type = PipelineOp.AsyncThread
|
||||
consumer_type = PipelineOp.TCGen05Mma
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
consumer = (consumer_type, consumer_group)
|
||||
|
||||
sync_object_full = PipelineAsync._make_sync_object(
|
||||
barrier_storage.align(min_align=8),
|
||||
num_stages,
|
||||
producer,
|
||||
)
|
||||
sync_object_empty = PipelineAsync._make_sync_object(
|
||||
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
||||
)
|
||||
|
||||
cta_v_size = (
|
||||
cute.size(cta_layout_vmnk, mode=[0]) if cta_layout_vmnk is not None else 1
|
||||
)
|
||||
cta_group = (
|
||||
cute.nvgpu.tcgen05.CtaGroup.ONE
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
|
||||
else cute.nvgpu.tcgen05.CtaGroup.TWO
|
||||
)
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
|
||||
# No mcast mask if we're not using 2CTA tcgen05 MMA
|
||||
producer_mask = None
|
||||
consumer_mask = None
|
||||
else:
|
||||
# If we're using 2CTA UMMAs, producer will arrive the mbar on leading CTA
|
||||
# We need to get the target cta_rank
|
||||
producer_mask = PipelineAsyncUmma._compute_leading_cta_rank(cta_v_size)
|
||||
# consumer needs to get the mask to signal
|
||||
consumer_mask = PipelineAsyncUmma._compute_peer_cta_mask(cta_layout_vmnk)
|
||||
|
||||
pipeline_init_wait(cta_layout_vmnk)
|
||||
|
||||
return PipelineAsyncUmma(
|
||||
sync_object_full,
|
||||
sync_object_empty,
|
||||
num_stages,
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
cta_group,
|
||||
)
|
||||
|
||||
def consumer_release(self, state: PipelineState):
|
||||
"""
|
||||
UMMA consumer release buffer empty, cta_group needs to be provided.
|
||||
"""
|
||||
self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineUmmaAsync(PipelineAsync):
    """
    PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines).
    """

    cta_group: cute.nvgpu.tcgen05.CtaGroup

    @staticmethod
    def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout):
        """
        Computes a mask to signal completion of tmem buffers for 2CTA kernels.
        """
        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )
        cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
        return cute.make_layout_image_mask(
            cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0
        )

    @staticmethod
    def _compute_peer_cta_rank():
        """
        Computes the rank of the leader CTA in each 2CTA pair, used to signal
        release of tmem buffers for 2CTA kernels.
        """
        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )
        return cta_rank_in_cluster // 2 * 2

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync.

        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TCGen05Mma
        consumer_type = PipelineOp.AsyncThread

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )

        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
            # Set mask to None if not using clusters (i.e. 1CTA kernels)
            producer_mask = None
        else:
            producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk)

        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
            # Set mask to None if not using 2CTA instructions
            consumer_mask = None
        else:
            consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank()

        cta_group = (
            cute.nvgpu.tcgen05.CtaGroup.ONE
            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
            else cute.nvgpu.tcgen05.CtaGroup.TWO
        )

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineUmmaAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            consumer_mask,
            cta_group,
        )

    def producer_commit(self, state: PipelineState):
        """
        UMMA producer commits the buffer as full; the cta_group needs to be provided.
        """
        self.sync_object_full.arrive(state.index, self.producer_mask, self.cta_group)

    def producer_tail(self, state: PipelineState):
        """
        Make sure the last used buffer's empty signal is visible to the producer.
        Producer tail is usually executed by the producer before exit, to avoid dangling
        mbarrier arrive signals after kernel exit.

        :param state: The pipeline state that points to the next useful buffer
        :type state: PipelineState
        """
        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )
        is_leader_cta = cta_rank_in_cluster % 2 == 0

        def then_body():
            # Assume state points to the next useful buffer,
            # so we only need to advance num_stages - 1 times to reach the last used buffer
            for i in range(self.num_stages - 1):
                state.advance()
            self.producer_acquire(state)

        if_generate(is_leader_cta, then_body)
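
# --- Illustrative usage sketch (not part of this file) ---
# A minimal UMMA-producer loop against the create/commit/tail API above.
# The pipeline/state arguments, `num_tiles`, and the MMA body are
# placeholders supplied by the surrounding kernel.
def _example_umma_producer_loop(pipeline, state, num_tiles):
    for _ in range(num_tiles):
        pipeline.producer_acquire(state)  # wait for the accumulator buffer to be empty
        # ... issue UMMA into the tmem buffer selected by state.index ...
        pipeline.producer_commit(state)   # mark the buffer full for the async consumer
        state.advance()
    pipeline.producer_tail(state)         # drain empty signals before kernel exit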
python/CuTeDSL/cutlass/pipeline/sm90.py (new file, 803 lines)
@ -0,0 +1,803 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings

import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, Int32, if_generate

from cutlass.pipeline import (
    CooperativeGroup,
    PipelineOp,
    SyncObject,
    MbarrierArray,
    TmaStoreFence,
    PipelineUserType,
    PipelineState,
    make_pipeline_state,
    pipeline_init_wait,
)

##############################################################################
# Pipeline classes
##############################################################################


@dataclass(frozen=True)
class PipelineAsync:
    """PipelineAsync is a generic pipeline class where both the producer and consumer are
    AsyncThreads. It also serves as a base class for specialized pipeline classes.

    This class implements a producer-consumer pipeline pattern where both sides operate
    asynchronously. The pipeline maintains synchronization state using barrier objects
    to coordinate between producer and consumer threads.

    The state transitions of one pipeline entry (mbarrier) can be represented as:

    .. table:: Pipeline State Transitions
       :widths: auto

       +-----------+-----------+-----------+-----------+-----------+-----------+
       | Barrier   | State     | p.acquire | p.commit  | c.wait    | c.release |
       +===========+===========+===========+===========+===========+===========+
       | empty_bar | empty     | <Return>  | n/a       | n/a       | -         |
       +-----------+-----------+-----------+-----------+-----------+-----------+
       | empty_bar | wait      | <Block>   | n/a       | n/a       | -> empty  |
       +-----------+-----------+-----------+-----------+-----------+-----------+
       | full_bar  | wait      | n/a       | -> full   | <Block>   | n/a       |
       +-----------+-----------+-----------+-----------+-----------+-----------+
       | full_bar  | full      | n/a       | -         | <Return>  | n/a       |
       +-----------+-----------+-----------+-----------+-----------+-----------+

    Where:

    - p: producer
    - c: consumer
    - <Block>: This action blocks until the other side transitions the barrier to a state that allows it to proceed
      - e.g. ``p.acquire()`` is blocked until ``empty_bar`` transitions to the ``empty`` state via ``c.release()``

    .. code-block:: text

        Array of mbarriers as circular buffer:

             Advance Direction
            <-------------------

            Producer         Consumer
               |                ^
               V                |
              +-----------------+
            --|X|X|W|D|D|D|D|R|X|<-.
           /  +-----------------+   \\
          |                         |
          `------------------------'

    Where:

    - X: Empty buffer (initial state)
    - W: Producer writing (producer is waiting for buffer to be empty)
    - D: Data ready (producer has written data to buffer)
    - R: Consumer reading (consumer is consuming data from buffer)

    **Example:**

    .. code-block:: python

        # Create pipeline with 5 stages
        pipeline = PipelineAsync.create(
            num_stages=5,  # number of pipeline stages
            producer_group=producer_warp,
            consumer_group=consumer_warp,
            barrier_storage=smem_ptr,  # smem pointer for array of mbarriers in shared memory
        )

        # Producer side
        producer = pipeline.make_pipeline_producer(producer_warp)
        for i in range(num_iterations):
            producer.acquire()  # Wait for buffer to be empty
            # Write data to pipeline buffer
            producer.commit()   # Signal buffer is full
            producer.advance()  # Move index to next stage

        # Consumer side
        consumer = pipeline.make_pipeline_consumer(consumer_warp)
        for i in range(num_iterations):
            consumer.wait()     # Wait for buffer to be full
            # Read data from pipeline buffer
            consumer.release()  # Signal buffer is empty
            consumer.advance()  # Move index to next stage
    """

    sync_object_full: SyncObject
    sync_object_empty: SyncObject
    num_stages: int
    producer_mask: Optional[Int32]
    consumer_mask: Optional[Int32]

    @staticmethod
    def _make_sync_object(
        barrier_storage: cute.Pointer,
        num_stages: int,
        agent: tuple[PipelineOp, CooperativeGroup],
        tx_count: int = 0,
    ) -> SyncObject:
        """
        Returns a SyncObject corresponding to an agent's PipelineOp.
        """
        if agent[0] in [
            PipelineOp.AsyncThread,
            PipelineOp.TmaLoad,
            PipelineOp.TCGen05Mma,
            PipelineOp.Composite,
        ]:
            return MbarrierArray(
                barrier_storage=barrier_storage,
                num_stages=num_stages,
                agent=agent,
                tx_count=tx_count,
            )
        elif agent[0] is PipelineOp.TmaStore:
            # Path taken for AsyncTmaStore
            return TmaStoreFence(num_stages=num_stages)
        else:
            assert False, "Error: Invalid PipelineOp specified."

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        barrier_storage: cute.Pointer = None,
        producer_mask: Int32 = None,
        consumer_mask: Int32 = None,
    ):
        """Creates and initializes a new PipelineAsync instance.

        This helper function computes necessary attributes and returns an instance of PipelineAsync
        with the specified configuration for producer and consumer synchronization.

        :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: int
        :param producer_group: `CooperativeGroup` for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: `CooperativeGroup` for the consumer agent
        :type consumer_group: CooperativeGroup
        :param producer_mask: Mask for signaling arrives for the producer agent, defaults to ``None``
        :type producer_mask: Int32, optional
        :param consumer_mask: Mask for signaling arrives for the consumer agent, defaults to ``None``
        :type consumer_mask: Int32, optional
        :return: A new PipelineAsync instance
        :rtype: PipelineAsync
        :raises ValueError: If barrier_storage is not a cute.Pointer instance
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.AsyncThread
        consumer_type = PipelineOp.AsyncThread

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )

        pipeline_init_wait()

        return PipelineAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            consumer_mask,
        )

    def producer_acquire(
        self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
    ):
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )

    def producer_try_acquire(self, state: PipelineState):
        return self.sync_object_empty.try_wait(state.index, state.phase)

    def producer_commit(self, state: PipelineState):
        self.sync_object_full.arrive(state.index, self.producer_mask)

    def consumer_wait(
        self, state: PipelineState, try_wait_token: Optional[Boolean] = None
    ):
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(state.index, state.phase),
        )

    def consumer_try_wait(self, state: PipelineState):
        return self.sync_object_full.try_wait(state.index, state.phase)

    def consumer_release(self, state: PipelineState):
        self.sync_object_empty.arrive(state.index, self.consumer_mask)

    def producer_get_barrier(self, state: PipelineState) -> cute.Pointer:
        return self.sync_object_full.get_barrier(state.index)

    def producer_tail(self, state: PipelineState):
        """
        Make sure the last used buffer's empty signal is visible to the producer.
        Producer tail is usually executed by the producer before exit, to avoid dangling
        mbarrier arrive signals after kernel exit.

        :param state: The pipeline state that points to the next useful buffer
        :type state: PipelineState
        """
        # Assume state points to the next useful buffer,
        # so we only need to advance num_stages - 1 times to reach the last used buffer
        for i in range(self.num_stages - 1):
            state.advance()
        self.producer_acquire(state)

    # Util methods to manage producer and consumer
    def make_pipeline_producer(self, group: CooperativeGroup):
        state = make_pipeline_state(PipelineUserType.Producer, self.num_stages)
        return PipelineProducer(self, state, group)

    def make_pipeline_consumer(self, group: CooperativeGroup):
        state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages)
        return PipelineConsumer(self, state, group)


@dataclass(frozen=True)
class PipelineTmaAsync(PipelineAsync):
    """
    PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops).
    """

    is_signalling_thread: Boolean

    @staticmethod
    @cute.jit
    def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32):
        """
        Initialize the empty barrier arrive signal.
        This function returns the destination cta rank and a boolean indicating
        whether the current thread is a signalling thread.
        """
        # Logic to optimally schedule Empty Arrives
        cluster_shape_vmnk = cta_layout_vmnk.shape

        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )

        tidx = tidx % 32
        is_signalling_thread = tidx < cute.size(cluster_shape_vmnk)
        dst_rank = tidx % cute.size(cluster_shape_vmnk)

        dst_cta_coord = cta_layout_vmnk.get_hier_coord(dst_rank)
        cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster)

        is_same_row = (
            dst_cta_coord[0] == cur_cta_coord[0]
            and dst_cta_coord[1] == cur_cta_coord[1]
            and dst_cta_coord[3] == cur_cta_coord[3]
        )
        is_same_col = (
            dst_cta_coord[0] == cur_cta_coord[0]
            and dst_cta_coord[2] == cur_cta_coord[2]
            and dst_cta_coord[3] == cur_cta_coord[3]
        )

        is_same_row_or_col = is_same_row or is_same_col
        is_signalling_thread_final = is_signalling_thread and is_same_row_or_col

        return dst_rank, is_signalling_thread_final
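
    # --- Illustrative sketch (not part of this file) ---
    # Plain-Python model of the signalling logic above for a hypothetical
    # 1x2x2x1 (V, M, N, K) cluster: lane i of the warp targets CTA rank i,
    # and only lanes whose target shares the current CTA's M or N coordinate
    # keep signalling (V and K are trivially equal here). The rank flattening
    # order is an assumption for illustration only.
    @staticmethod
    def _example_signal_map(cur_coord, cluster_coords):
        decisions = []
        for lane, dst in enumerate(cluster_coords):
            same_row = dst[1] == cur_coord[1]  # same M coordinate
            same_col = dst[2] == cur_coord[2]  # same N coordinate
            decisions.append((lane, same_row or same_col))
        return decisions

    # _example_signal_map((0, 1, 0, 0),
    #                     [(0, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 1, 1, 0)])
    # -> [(0, True), (1, True), (2, False), (3, True)]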

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        tidx: Optional[Int32] = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.

        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        :param tidx: Thread index of the consumer async threads
        :type tidx: Int32 | None
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.AsyncThread

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer, tx_count
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )
        if tidx is None:
            tidx, _, _ = cute.arch.thread_idx()
        if cta_layout_vmnk is None:
            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
        (
            dst_rank,
            is_signalling_thread,
        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
        # No destination rank is needed for single-CTA clusters
        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
            dst_rank = None

        producer_mask = None

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            dst_rank,
            is_signalling_thread,
        )

    def producer_acquire(
        self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
    ):
        """
        TMA producer acquire conditionally waits on buffer empty and sets the transaction barrier.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )
        self.sync_object_full.arrive(state.index, self.producer_mask)

    def producer_commit(self, state: PipelineState):
        """
        TMA producer commit is a noop since the TMA instruction itself updates the transaction count.
        """
        pass

    def consumer_release(self, state: PipelineState):
        """
        TMA consumer release conditionally signals the empty buffer to the producer.
        """
        if_generate(
            self.is_signalling_thread,
            lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask),
        )


@dataclass(frozen=True)
class PipelineTmaMultiConsumersAsync(PipelineAsync):
    """
    PipelineTmaMultiConsumersAsync is used for TMA producers and UMMA+Async consumers.
    """

    is_leader_cta: bool
    sync_object_empty_umma: SyncObject
    sync_object_empty_async: SyncObject
    cta_group: cute.nvgpu.tcgen05.CtaGroup

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group_umma: CooperativeGroup,
        consumer_group_async: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaMultiConsumersAsync.

        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group_umma: CooperativeGroup for the UMMA consumer agent
        :type consumer_group_umma: CooperativeGroup
        :param consumer_group_async: CooperativeGroup for the AsyncThread consumer agent
        :type consumer_group_async: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.Composite
        consumer_type_umma = PipelineOp.TCGen05Mma
        consumer_type_async = PipelineOp.AsyncThread

        if consumer_group_umma.agent != consumer_group_async.agent:
            raise ValueError(
                "UMMA and AsyncThread consumer groups must be the same agent"
            )

        if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1:
            raise ValueError(
                f"PipelineTmaMultiConsumersAsync is not verified for cta_layout_vmnk != 1, cta_layout_vmnk:{cta_layout_vmnk}"
            )

        consumer_group = CooperativeGroup(
            consumer_group_umma.agent,
            consumer_group_umma.size + consumer_group_async.size,
        )

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer, tx_count
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )
        sync_object_empty_umma = sync_object_empty.recast_to_new_op_type(
            consumer_type_umma
        )
        sync_object_empty_async = sync_object_empty.recast_to_new_op_type(
            consumer_type_async
        )

        # No mcast mask if not using clusters
        producer_mask = None
        consumer_mask = None
        # All threadblocks are leaders if not using clusters
        is_leader_cta = True
        cta_group = (
            cute.nvgpu.tcgen05.CtaGroup.ONE
            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
            else cute.nvgpu.tcgen05.CtaGroup.TWO
        )

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaMultiConsumersAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            consumer_mask,
            is_leader_cta,
            sync_object_empty_umma,
            sync_object_empty_async,
            cta_group,
        )

    def producer_acquire(
        self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
    ):
        """
        TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )
        if_generate(
            self.is_leader_cta,
            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
        )

    def producer_commit(self, state: PipelineState):
        """
        TMA producer commit is a noop since the TMA instruction itself updates the transaction count.
        """
        pass

    def consumer_release(self, state: PipelineState, op_type: PipelineOp):
        if op_type == PipelineOp.TCGen05Mma:
            self.sync_object_empty_umma.arrive(
                state.index, self.consumer_mask, self.cta_group
            )
        elif op_type == PipelineOp.AsyncThread:
            self.sync_object_empty_async.arrive(state.index, self.consumer_mask)
        else:
            raise ValueError(f"Invalid PipelineOp specified. op_type:{op_type}")
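
# --- Illustrative sketch (not part of this file) ---
# Each consumer kind releases the shared empty barrier with its own op type,
# so a stage is recycled only after both consumers have arrived. The state
# objects are placeholders owned by the respective warps.
def _example_release_both(pipeline, mma_state, async_state):
    pipeline.consumer_release(mma_state, PipelineOp.TCGen05Mma)     # UMMA warp
    pipeline.consumer_release(async_state, PipelineOp.AsyncThread)  # async warps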

@dataclass(frozen=True)
class PipelineTmaStore(PipelineAsync):
    """
    PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers.
    """

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaStore.

        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        """
        producer_type = PipelineOp.TmaStore

        producer = (producer_type, producer_group)

        sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer)

        return PipelineTmaStore(sync_object_full, None, num_stages, None, None)

    def producer_acquire(self):
        self.sync_object_full.wait()

    def producer_commit(self):
        self.sync_object_full.arrive()

    def consumer_wait(self):
        assert False, "Error: PipelineTmaStore does not have a consumer agent."

    def consumer_release(self):
        assert False, "Error: PipelineTmaStore does not have a consumer agent."

    def producer_tail(self):
        self.sync_object_full.tail()
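
# --- Illustrative usage sketch (not part of this file) ---
# The epilogue throttles in-flight TMA stores through the fence; there is
# no consumer side. `num_tiles` and the store body are placeholders.
def _example_tma_store_loop(pipeline, num_tiles):
    for _ in range(num_tiles):
        pipeline.producer_acquire()  # wait until a store slot is free
        # ... issue the TMA store for this tile ...
        pipeline.producer_commit()   # track the newly issued store
    pipeline.producer_tail()         # drain all outstanding stores before exit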

#################################################################
# Utilities to help users of the pipeline simplify the workflow
#################################################################


class PipelineProducer:
    """A class representing a producer in an asynchronous pipeline.

    The Producer class manages the producer side of an asynchronous pipeline, handling
    synchronization and state management for producing data. It provides methods for
    acquiring, committing, and advancing through pipeline stages.

    :ivar _pipeline: The asynchronous pipeline this producer belongs to
    :type _pipeline: PipelineAsync
    :ivar _state: The current state of the producer in the pipeline
    :type _state: PipelineState
    :ivar _group: The cooperative group this producer operates in
    :type _group: CooperativeGroup

    **Examples:**

    .. code-block:: python

        pipeline = PipelineAsync.create(...)
        producer = pipeline.make_pipeline_producer(producer_group)
        for i in range(iterations):
            producer.acquire()  # Wait for buffer to be empty
            # Produce data
            producer.commit()   # Signal data is ready
            producer.advance()  # Move to next stage
    """

    _pipeline: PipelineAsync
    _state: PipelineState
    _group: CooperativeGroup

    def __init__(self, pipeline, state, group: CooperativeGroup):
        """Initialize a new Producer instance.

        :param pipeline: The pipeline this producer belongs to
        :type pipeline: PipelineAsync
        :param state: Initial pipeline state
        :type state: PipelineState
        :param group: The cooperative group for synchronization
        :type group: CooperativeGroup
        """
        self._pipeline = pipeline
        self._state = state
        self._group = group

    @property
    def index(self):
        """Get the index of the current pipeline stage."""
        return self._state.index

    def get_barrier(self):
        """Get the barrier pointer for the current pipeline stage.

        :return: Pointer to the barrier for the current stage
        :rtype: cute.Pointer
        """
        return self._pipeline.producer_get_barrier(self._state)

    def acquire(self):
        """Wait for the current buffer to be empty before producing data.
        This is a blocking operation.
        """
        self._pipeline.producer_acquire(self._state)

    def try_acquire(self):
        """Try to acquire the current buffer without blocking.

        :return: True if acquisition was successful, False otherwise
        :rtype: bool
        """
        return self._pipeline.producer_try_acquire(self._state)

    def commit(self):
        """Signal that data production is complete for the current stage.
        This allows consumers to start processing the data.
        """
        self._pipeline.producer_commit(self._state)

    def tail(self):
        """Ensure all used buffers are properly synchronized before producer exit.
        This should be called before the producer finishes to avoid dangling signals.
        """
        self._pipeline.producer_tail(self._state)

    def advance(self):
        """Move to the next pipeline stage."""
        self._state.advance()

    def __extract_mlir_values__(self):
        """Extract MLIR values from the current state.

        :return: List of MLIR values representing the current state
        :rtype: list
        """
        # TODO: need to handle pipeline as well
        return self._state.__extract_mlir_values__()

    def __new_from_mlir_values__(self, values):
        """Create a new Producer instance from MLIR values.

        :param values: MLIR values to initialize the state
        :type values: Any
        :return: New Producer instance with state initialized from values
        :rtype: Producer
        """
        return PipelineProducer(
            self._pipeline, self._state.__new_from_mlir_values__(values), self._group
        )


class PipelineConsumer:
    """A class representing a consumer in an asynchronous pipeline.

    The Consumer class manages the consumer side of an asynchronous pipeline, handling
    synchronization and state management for consuming data. It provides methods for
    waiting, releasing, and advancing through pipeline stages.

    :ivar _pipeline: The asynchronous pipeline this consumer belongs to
    :type _pipeline: PipelineAsync
    :ivar _state: The current state of the consumer in the pipeline
    :type _state: PipelineState
    :ivar _group: The cooperative group this consumer operates in
    :type _group: CooperativeGroup

    **Examples:**

    .. code-block:: python

        pipeline = PipelineAsync.create(...)
        consumer = pipeline.make_pipeline_consumer(consumer_group)
        for i in range(iterations):
            consumer.wait()     # Wait for data to be ready
            # Consume data
            consumer.release()  # Signal buffer is empty
            consumer.advance()  # Move to next stage
    """

    _pipeline: PipelineAsync
    _state: PipelineState
    _group: CooperativeGroup

    def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup):
        """Initialize a new Consumer instance.

        :param pipeline: The pipeline this consumer belongs to
        :type pipeline: PipelineAsync
        :param state: Initial pipeline state
        :type state: PipelineState
        :param group: The cooperative group for synchronization
        :type group: CooperativeGroup
        """
        self._pipeline = pipeline
        self._group = group
        self._state = state

    @property
    def index(self):
        """Get the index of the current pipeline stage."""
        return self._state.index

    def wait(self):
        """Wait for data to be ready in the current buffer.
        This is a blocking operation.
        """
        self._pipeline.consumer_wait(self._state)

    def try_wait(self):
        """Try to check if data is ready without blocking.

        :return: True if data is ready, False otherwise
        :rtype: bool
        """
        return self._pipeline.consumer_try_wait(self._state)

    def release(self):
        """Signal that data consumption is complete for the current stage.
        This allows producers to start producing new data.
        """
        self._pipeline.consumer_release(self._state)

    def advance(self):
        """Move to the next pipeline stage."""
        self._state.advance()

    def __extract_mlir_values__(self):
        """Extract MLIR values from the current state.

        :return: List of MLIR values representing the current state
        :rtype: list
        """
        return self._state.__extract_mlir_values__()

    def __new_from_mlir_values__(self, values):
        """Create a new Consumer instance from MLIR values.

        :param values: MLIR values to initialize the state
        :type values: Any
        :return: New Consumer instance with state initialized from values
        :rtype: Consumer
        """
        # TODO: need to call pipeline.__new_from_mlir_values__ recursively
        return PipelineConsumer(
            self._pipeline, self._state.__new_from_mlir_values__(values), self._group
        )
@ -29,6 +29,7 @@ from cutlass.cute.typing import (
from cutlass.cute.runtime import from_dlpack
import cutlass.cute as cute
import torch
from cuda import cuda


def dtype(ty: Type[Numeric]):
@ -94,12 +95,13 @@ def create_and_permute_torch_tensor(
    init_config: Optional[
        Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig]
    ] = None,
    device: Optional[torch.device] = None,
) -> "torch.Tensor":
    """
    Create a torch tensor with the specified shape and dtype. Optionally permute it and initialize it with the specified init type and config.
    """
    init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32
    init_torch_tensor = torch.empty(*shape, dtype=init_dtype)
    init_torch_tensor = torch.empty(*shape, dtype=init_dtype, device=device)
    if init_type == TensorInitType.SKIP:
        assert init_config is None
        f32_torch_tensor = init_torch_tensor
@ -167,3 +169,122 @@ def convert_cute_tensor(
    # Copy and convert from f32 cute tensor to dtype cute tensor
    cute.testing.convert(fp32_cute_tensor, cute_tensor)
    return cute_tensor


def default_stream() -> cuda.CUstream:
    """
    Get the default CUstream from the torch stream.
    """
    torch_stream = torch.cuda.default_stream()
    stream = cuda.CUstream(torch_stream.cuda_stream)
    return stream


def current_stream() -> cuda.CUstream:
    """
    Get the current CUstream from the torch stream.
    """
    torch_stream = torch.cuda.current_stream()
    stream = cuda.CUstream(torch_stream.cuda_stream)
    return stream


def matrix(
    l: int,
    mode0: int,
    mode1: int,
    is_mode0_major: bool,
    cutlass_dtype: Type[Numeric],
    init_type: TensorInitType = TensorInitType.RANDOM,
    init_config: Optional[
        Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig]
    ] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Create a torch tensor for a matrix.

    :param l: batch extent (L mode) of the tensor
    :param mode0: mode0 extent of the matrix
    :param mode1: mode1 extent of the matrix
    :param is_mode0_major: whether the matrix is mode0-major
    :param cutlass_dtype: cutlass dtype of the matrix
    :param init_type: type of initialization
    :param init_config: configuration for initialization
    :param device: target torch device
    """

    shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
    permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)

    if cutlass_dtype.is_float and cutlass_dtype.width <= 8:
        torch_dtype = torch.int8
    else:
        torch_dtype = dtype(cutlass_dtype)

    if init_type == TensorInitType.RANDOM and init_config is None:
        if torch_dtype.is_signed:
            min_val = -2
            max_val = 2
        else:
            min_val = 0
            max_val = 4
        init_config = RandomInitConfig(min_val=min_val, max_val=max_val)

    # Create dtype torch tensor
    torch_tensor = create_and_permute_torch_tensor(
        shape,
        torch_dtype,
        permute_order=permute_order,
        init_type=init_type,
        init_config=init_config,
        device=device,
    )

    return torch_tensor


def cute_tensor_like(
    data_ref: torch.Tensor,
    cutlass_dtype: Type[Numeric],
    is_dynamic_layout: bool,
    assumed_align: Optional[int] = None,
) -> tuple[Tensor, torch.Tensor]:
    """
    Create a cute tensor using a torch tensor as the data source.

    :param data_ref: torch tensor used as the data source
    :param cutlass_dtype: cutlass dtype of the cute tensor
    :param is_dynamic_layout: whether the cute tensor uses a dynamic layout
    :param assumed_align: assumed alignment of the cute tensor
    """

    # Allocate a device buffer for the cute tensor; sub-byte float types are
    # backed by an int8 buffer
    if cutlass_dtype.is_float and cutlass_dtype.width <= 8:
        torch_dtype = torch.int8
    else:
        torch_dtype = dtype(cutlass_dtype)
    torch_tensor = torch.empty_like(data_ref, dtype=torch_dtype, device="cuda")

    # Create a cute tensor using the device buffer
    cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align)
    cute_tensor.element_type = cutlass_dtype
    if is_dynamic_layout:
        for i, stride in enumerate(torch_tensor.stride()):
            if stride == 1:
                leading_dim = i
                break
        cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)

    # Initialize the cute tensor data
    if cutlass_dtype.is_float and cutlass_dtype.width <= 8:
        cute_tensor = convert_cute_tensor(
            data_ref.to(dtype=torch.float32),
            cute_tensor,
            cutlass_dtype,
            is_dynamic_layout,
        )
    else:
        torch_tensor.copy_(data_ref.to(dtype=torch_dtype))

    return cute_tensor, torch_tensor
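
# --- Illustrative usage sketch (not part of this file) ---
# Build a reference batch with `matrix(...)` and wrap it as a cute tensor.
# The dtype and extents are arbitrary examples; `cutlass.Float16` is assumed
# to be the DSL's half-precision Numeric type.
def _example_make_operand():
    import cutlass

    a_ref = matrix(
        1, 128, 64, is_mode0_major=False, cutlass_dtype=cutlass.Float16, device="cuda"
    )
    a_cute, a_dev = cute_tensor_like(
        a_ref, cutlass.Float16, is_dynamic_layout=True, assumed_align=16
    )
    return a_cute, a_dev  # DSL-visible tensor and the torch buffer backing it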
@ -15,20 +15,6 @@ from .static_persistent_tile_scheduler import (
    StaticPersistentTileScheduler,
)

from .pipeline import (
    Agent,
    CooperativeGroup,
    PipelineUserType,
    PipelineState,
    make_pipeline_state,
    PipelineAsync,
    PipelineTmaAsync,
    PipelineTmaUmma,
    PipelineUmmaAsync,
    PipelineTmaStore,
    pipeline_init_wait,
)

from .hardware_info import (
    HardwareInfo,
)
@ -65,6 +51,8 @@ from .smem_allocator import SmemAllocator
from .layout import LayoutEnum

__all__ = [
    "SmemAllocator",
    "LayoutEnum",
    "WorkTileInfo",
    "PersistentTileSchedulerParams",
    "StaticPersistentTileScheduler",
@ -51,8 +51,13 @@ from cutlass.cute.nvgpu.tcgen05 import (
    is_tmem_load,
    get_tmem_copy_properties,
)
from cutlass.cute.nvgpu.cpasync import (
    CopyBulkTensorTileG2SMulticastOp,
    CopyBulkTensorTileG2SOp,
)
from cutlass.utils.layout import LayoutEnum


@dsl_user_op
def compute_epilogue_tile_shape(
    cta_tile_shape: cute.Shape,
@ -716,6 +721,7 @@ def make_smem_layout_b(

    return b_smem_layout_staged


@dsl_user_op
def get_smem_layout_atom_epi(
    layout: LayoutEnum,
@ -827,6 +833,7 @@ SMEM_CAPACITY = {
    "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
}


@dsl_user_op
def make_trivial_tiled_mma(
    ab_dtype: Type[Numeric],
@ -908,3 +915,139 @@ def make_trivial_tiled_mma(
        raise TypeError(f"unsupported ab_dtype, got {ab_dtype}")

    return cute.make_tiled_mma(cute.make_mma_atom(mma_op))


@dsl_user_op
def cluster_shape_to_tma_atom_A(
    cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None
) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]:
    """
    Select the appropriate TMA copy atom for A based on the number of SMs and the multicast flag.

    :param cluster_shape_mnk: The shape of the cluster
    :type cluster_shape_mnk: cute.Shape
    :param atom_thr_id: The thread ID layout of the atom
    :type atom_thr_id: cute.Layout

    :return: The appropriate TMA copy atom kind
    :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp

    :raise ValueError: If the atom_sm_cnt is invalid
    :raise ValueError: If the cluster shape is not divisible by the atom SM count
    """
    atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip)
    mcast = not (cute.size(cluster_shape_mnk, mode=[1], loc=loc, ip=ip) == 1)
    cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip)

    if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int):
        raise ValueError(
            f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}"
        )

    if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0:
        raise ValueError(
            f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}"
        )

    if atom_sm_cnt == 2 and mcast:
        return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO)
    elif atom_sm_cnt == 2 and not mcast:
        return CopyBulkTensorTileG2SOp(CtaGroup.TWO)
    elif atom_sm_cnt == 1 and mcast:
        return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE)
    elif atom_sm_cnt == 1 and not mcast:
        return CopyBulkTensorTileG2SOp(CtaGroup.ONE)

    raise ValueError(
        f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}"
    )
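
# --- Illustrative sketch (not part of this file) ---
# Atom selection for A. Using `cute.make_layout(2)` as a stand-in for a
# 2-SM atom thread-ID layout is an assumption for illustration only;
# B and SFB selection below follow the same pattern.
def _example_select_atom_a():
    atom_thr_id = cute.make_layout(2)  # hypothetical 2-SM MMA atom
    # N-extent 1 -> no multicast: CopyBulkTensorTileG2SOp(CtaGroup.TWO)
    no_mcast = cluster_shape_to_tma_atom_A((2, 1, 1), atom_thr_id)
    # N-extent 2 -> multicast: CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO)
    mcast = cluster_shape_to_tma_atom_A((2, 2, 1), atom_thr_id)
    return no_mcast, mcast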


@dsl_user_op
def cluster_shape_to_tma_atom_B(
    cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None
) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]:
    """
    Select the appropriate TMA copy atom for B based on the number of SMs and the multicast flag.

    :param cluster_shape_mnk: The shape of the cluster
    :type cluster_shape_mnk: cute.Shape
    :param atom_thr_id: The thread ID layout of the atom
    :type atom_thr_id: cute.Layout

    :return: The appropriate TMA copy atom kind
    :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp

    :raise ValueError: If the atom_sm_cnt is invalid
    :raise ValueError: If the cluster shape is not divisible by the atom SM count
    """
    atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip)
    mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == atom_sm_cnt)
    cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip)

    if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int):
        raise ValueError(
            f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}"
        )

    if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0:
        raise ValueError(
            f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}"
        )

    if atom_sm_cnt == 2 and mcast:
        return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO)
    elif atom_sm_cnt == 2 and not mcast:
        return CopyBulkTensorTileG2SOp(CtaGroup.TWO)
    elif atom_sm_cnt == 1 and mcast:
        return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE)
    elif atom_sm_cnt == 1 and not mcast:
        return CopyBulkTensorTileG2SOp(CtaGroup.ONE)

    raise ValueError(
        f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}"
    )


@dsl_user_op
def cluster_shape_to_tma_atom_SFB(
    cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None
) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]:
    """
    Select the appropriate TMA copy atom for SFB based on the number of SMs and the multicast flag.

    :param cluster_shape_mnk: The shape of the cluster
    :type cluster_shape_mnk: cute.Shape
    :param atom_thr_id: The thread ID layout of the atom
    :type atom_thr_id: cute.Layout

    :return: The appropriate TMA copy atom kind
    :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp

    :raise ValueError: If the atom_sm_cnt is invalid
    :raise ValueError: If the cluster shape is not divisible by the atom SM count
    """
    atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip)
    mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == 1)
    cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip)

    if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int):
        raise ValueError(
            f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}"
        )

    if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0:
        raise ValueError(
            f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}"
        )

    if atom_sm_cnt == 2:
        return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO)
    elif atom_sm_cnt == 1 and mcast:
        return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE)
    elif atom_sm_cnt == 1 and not mcast:
        return CopyBulkTensorTileG2SOp(CtaGroup.ONE)

    raise ValueError(
        f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}"
    )
@ -96,6 +96,10 @@ def make_trivial_tiled_mma(
    acc_dtype: Type[Numeric],
    atom_layout_mnk: Tuple[int, int, int],
    tiler_mn: Tuple[int, int],
    a_source: OperandSource = OperandSource.SMEM,
    *,
    loc=None,
    ip=None,
) -> cute.TiledMma:
    """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.
    By default, the MMA atom is created with SMEM operand source for A.
@ -131,7 +135,7 @@ def make_trivial_tiled_mma(
        a_dtype,
        acc_dtype,
        (*tiler_mn, 16),
        OperandSource.SMEM,
        a_source,
        a_leading_mode,
        b_leading_mode,
    )
@ -144,7 +148,7 @@ def make_trivial_tiled_mma(
        b_dtype,
        acc_dtype,
        (*tiler_mn, 32),
        OperandSource.SMEM,
        a_source,
        a_leading_mode,
        b_leading_mode,
    )
File diff suppressed because it is too large
Load Diff
@ -18,41 +18,25 @@ from cutlass.cute.arch import get_dyn_smem


class SmemAllocator:
    """
    A class for managing shared memory allocation on GPU.
    """A class for managing shared memory allocation on GPU.

    This class manages a chunk of shared memory and provide APIs for sub-allocation
    This class manages a chunk of shared memory and provides APIs for sub-allocation
    inside the chunk.

    Attributes
    ----------
    _base : cute.Pointer as i8 typed dynamic value
        The current base address of the shared memory.
    :ivar _base: The current base address of the shared memory as an i8 typed dynamic value.
    :type _base: cute.Pointer
    :ivar _allocated_bytes: The total number of bytes allocated in shared memory.
    :type _allocated_bytes: int

    _allocated_bytes:
        The bytes allocated in shared memory.

    Methods
    -------
    allocate(num_bytes, alignment)
        Allocates num_bytes in the shared memory with the given byte alignment.

    allocate_value(value_ty, num_elems)
        Allocates num_elems of value_ty values in the shared memory.

    allocate_tensor(value_ty, layout, alignment)
        Allocates a tensor in the shared memory with given layout and byte alignment.

    Notes
    -----
    This class is responsible for managing the allocation of tensors in shared memory.
    .. note::
        This class is responsible for managing the allocation of tensors in shared memory.
        The base pointer is aligned to 1024 bytes upon initialization.
    """

    def __init__(self):
        """
        Initializes the SmemAllocator instance with dynamic smem base ptr,
        which is i8 type and aligned to 1024.
        """Initialize the SmemAllocator instance.

        Creates a dynamic shared memory base pointer of type i8, aligned to 1024 bytes.
        """
        self._base = get_dyn_smem(Int8, alignment=1024)
        self._allocated_bytes = 0
@ -64,30 +48,19 @@ class SmemAllocator:
    def allocate(self, size_or_type: cute.struct, byte_alignment: int): ...

    def allocate(self, size_or_type, byte_alignment: int = 1) -> int:
        """Allocate a block of memory with specified size and alignment.

        This method adjusts the base pointer to ensure proper alignment and updates
        the internal state to track allocated memory.

        :param size_or_type: The number of bytes to allocate or a struct class
        :type size_or_type: Union[int, cute.struct]
        :param byte_alignment: The byte alignment requirement, defaults to 1 (no alignment)
        :type byte_alignment: int, optional
        :return: Pointer to the start of the allocated memory block or struct instance
        :rtype: cute.Pointer
        :raises ValueError: If size is negative or alignment is less than 1
        """
        Allocates a block of memory with the specified size and byte alignment.

        This method adjusts the base cute.Pointer to ensure that the allocated memory
        is aligned according to the specified byte alignment. It updates the internal
        state to reflect the new base cute.Pointer and the total allocated bytes.

        Parameters
        ----------
        size_or_type : int or struct
            The number of bytes to allocate or struct class.
        byte_alignment : int
            The byte alignment requirement for the allocation. Defaults to 1 (no alignment).

        Returns
        ----------
        A cute.Pointer to the start of the allocated memory block or struct instance.

        Raises
        ----------
        ValueError
            If num_bytes is negative or if byte_alignment is less than 1.
        """

        if isinstance(size_or_type, cute.struct):
            alignment = max(byte_alignment, size_or_type.__alignof__())
            base_ptr = self.allocate(size_or_type.__sizeof__(), alignment)
@ -110,27 +83,16 @@ class SmemAllocator:
        return ptr

    def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1):
        """
        Allocates num_elems values of element_type in shared memory.
        """Allocate an array of elements in shared memory.

        This method calls allocate() to return a byte ptr, pointing to start of shared
        memory. Then calls cute.recast_ptr() to recast this byte cute.Pointer to element_type.

        Parameters
        ----------
        element_type : Type[Numeric]
            The type of the values in the tensor.
        num_elems : int, optional
            The number of elements for each allocation. Defaults to 1.

        Returns
        ----------
        A value_type cute.Pointer to the start of the allocated memory block.

        Raises
        ----------
        ValueError
            If num_elems is less than 1.
        :param element_type: The type of elements to allocate
        :type element_type: Type[Numeric]
        :param num_elems: Number of elements to allocate, defaults to 1
        :type num_elems: int, optional
        :return: Pointer to the start of the allocated array
        :rtype: cute.Pointer
        :raises ValueError: If num_elems is less than 1
        :raises TypeError: If element_type is not a Numeric type
        """
        if num_elems < 1:
            raise ValueError("num_elems must be at least 1")
@ -152,28 +114,21 @@ class SmemAllocator:
        byte_alignment: int = 1,
        swizzle: cute.Swizzle = None,
    ):
        """
        Allocates a tensor in the shared memory with value type, layout and byte alignment.
        """Allocate a tensor in shared memory.

        Parameters
        ----------
        element_type : Type[Numeric]
            The type of the values in the tensor.
        layout : int | DynamicInt | cute.Layout | cute.ComposedLayout
            The layout of the tensor.
        byte_alignment : int, optional
            The byte alignment requirement for the allocation. Defaults to 1 (no alignment).
        swizzle : cute.Swizzle
            A swizzle for the iterator (for position-dependent swizzling).

        Returns
        -------
        tensor : cute.Tensor
            The allocated tensor with specified value type, layout and byte alignment.

        Notes
        -----
        The base address is updated to point to the next available memory location.
        :param element_type: The type of elements in the tensor
        :type element_type: Type[Numeric]
        :param layout: The layout specification for the tensor
        :type layout: Union[int, cute.Layout, cute.ComposedLayout]
        :param byte_alignment: The byte alignment requirement, defaults to 1
        :type byte_alignment: int, optional
        :param swizzle: Swizzle for position-dependent swizzling, defaults to None
        :type swizzle: cute.Swizzle, optional
        :return: The allocated tensor with specified properties
        :rtype: cute.Tensor
        :raises TypeError: If element_type is not a Numeric type or if swizzle conflicts with layout
        :raises ValueError: If allocation is not byte-aligned
        :raises NotImplementedError: If dynamic layout is specified
        """
        if not isinstance(element_type, NumericMeta):
            raise TypeError(
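
# --- Illustrative usage sketch (not part of this diff) ---
# Typical allocation flow inside a kernel; the element types, layout, and
# alignment are arbitrary examples, and `Float16`/`Int32` are assumed to be
# the DSL Numeric types available alongside `Int8` in this module's scope.
def _example_smem_usage():
    smem = SmemAllocator()  # base pointer aligned to 1024 bytes
    a_smem = smem.allocate_tensor(
        Float16, cute.make_layout((128, 64)), byte_alignment=128
    )
    flags = smem.allocate_array(Int32, num_elems=4)
    return a_smem, flags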
@ -23,6 +23,11 @@ from ..base_dsl.ast_helpers import (
    dynamic_expr,
    assert_executor,
    bool_cast,
    compare_executor,
    any_executor,
    all_executor,
    range_value_check,
    range_perf_warning,
)

from ..base_dsl import *

@ -20,6 +20,7 @@ from inspect import isclass
import functools
import pkgutil
from dataclasses import is_dataclass
from collections.abc import Sequence

from ..base_dsl import *
from ..base_dsl import compiler
@ -51,6 +52,11 @@ from ..base_dsl.ast_helpers import (
    while_executor,
    assert_executor,
    bool_cast,
    compare_executor,
    any_executor,
    all_executor,
    range_value_check,
    range_perf_warning,
)
from ..base_dsl.runtime.dlpack_runtime import (
    get_cute_tensor_c_pointer,
@ -67,18 +73,6 @@ from .cutlass_ast_decorators import (
    _while_execute_dynamic,
)

# =============================================================================
# Set the AST decorator
# =============================================================================

# Set the DSL specific functions
executor.set_functions(
    is_dynamic_expression,
    _loop_execute_range_dynamic,
    _if_execute_dynamic,
    _while_execute_dynamic,
)


# =============================================================================
# Cutlass DSL Base Abstract Class
|
||||
)
|
||||
return value
|
||||
|
||||
# Non-DSL dynamic cond should be handled before this.
|
||||
if const_expr(not is_dynamic_expression(cond)):
|
||||
raise DSLRuntimeError("Conditional expression must be dynamic")
|
||||
|
||||
@ -1089,6 +1082,7 @@ def for_generate(
|
||||
iter_args: Optional[Sequence[ir.Value]] = None,
|
||||
*,
|
||||
unroll: LoopUnroll = None,
|
||||
pipelining=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
@ -1126,6 +1120,9 @@ def for_generate(
|
||||
if unroll is not None:
|
||||
for_op.attributes["loop_annotation"] = unroll
|
||||
|
||||
if pipelining is not None:
|
||||
for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
|
||||
|
||||
iv = for_op.induction_variable
|
||||
new_results = new_from_mlir_values(iter_args, for_op.results)
|
||||
new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args)
|
||||
@ -1319,3 +1316,122 @@ def while_generate(
|
||||
Generate a WhileLoopContext for a dynamic loop.
|
||||
"""
|
||||
return WhileLoopContext(inputs, condition, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def equal(lhs, rhs):
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return lhs == rhs
|
||||
|
||||
# Both sequence
|
||||
if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
|
||||
# Short-circuit for unequal length
|
||||
if len(lhs) != len(rhs):
|
||||
return False
|
||||
return all_(equal(l, r) for l, r in zip(lhs, rhs))
|
||||
return lhs == rhs
|
||||
|
||||
|
||||
def in_(lhs, rhs, op):
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return lhs in rhs
|
||||
|
||||
if not isinstance(rhs, Sequence):
|
||||
raise DSLRuntimeError(
|
||||
f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
|
||||
)
|
||||
|
||||
return any_(equal(lhs, r) for r in rhs)
|
||||
|
||||
|
||||
def _lt_gt(lhs, rhs, op):
    def native_lt_gt(lhs, rhs, op):
        if op == "<":
            return lhs < rhs
        elif op == ">":
            return lhs > rhs
        else:
            raise DSLRuntimeError(f"Unsupported comparison operator: {op}")

    if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
        return native_lt_gt(lhs, rhs, op)

    # Both are sequences. Comparisons other than == and != do not allow
    # mixing different sequence types.
    if (
        isinstance(lhs, Sequence)
        and isinstance(rhs, Sequence)
        and type(lhs) == type(rhs)
    ):
        unequal_found = False
        comp_results = []
        mask = []
        for l, r in zip(lhs, rhs):
            is_equal = equal(l, r)
            mask.append(not_(or_(is_equal, unequal_found)))
            # Accumulate across positions so the mask is set only at the first
            # difference; assigning not_(is_equal) alone would reset the flag.
            unequal_found = or_(unequal_found, not_(is_equal))
            comp_results.append(_lt_gt(l, r, op))

        result = any_(and_(r, m) for r, m in zip(comp_results, mask))

        if len(lhs) != len(rhs):
            # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
            # If one sequence is an initial sub-sequence of the other, the
            # shorter sequence is the smaller (lesser) one.
            has_valid_mask = any_(mask)
            if op == "<":
                length_result = len(lhs) < len(rhs)
            elif op == ">":
                length_result = len(lhs) > len(rhs)
            if type(has_valid_mask) == bool:
                return result if has_valid_mask else length_result
            else:
                return select_(has_valid_mask, result, length_result)
        else:
            return result
    else:
        return native_lt_gt(lhs, rhs, op)


def greater_than(lhs, rhs):
    return _lt_gt(lhs, rhs, ">")


def less_than(lhs, rhs):
    return _lt_gt(lhs, rhs, "<")
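
To see why the mask selects only the first differing position, here is the same first-difference scan on plain Python booleans (editor's sketch; `or_`, `and_`, `not_`, and `any_` are assumed to reduce to their Python counterparts on static values):

# Editor's sketch: lexicographic "<" over two static sequences.
lhs, rhs = (1, 5, 0), (1, 3, 9)
unequal_found, mask, comps = False, [], []
for l, r in zip(lhs, rhs):
    is_equal = l == r
    mask.append(not (is_equal or unequal_found))   # True only at the first mismatch
    unequal_found = unequal_found or not is_equal
    comps.append(l < r)
result = any(c and m for c, m in zip(comps, mask))
assert result == (lhs < rhs) == False   # first mismatch is 5 vs 3, and 5 < 3 is False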


def _compare_executor(left, comparators, ops):
    result = left
    for comparator, op in zip(comparators, ops):
        # 'is' and 'is not' are pure Python operators
        if op == "is":
            result = result is comparator
        elif op == "is not":
            result = result is not comparator
        elif op in ["in", "not in"]:
            result = in_(left, comparator, op)
        elif op in ["==", "!="]:
            result = equal(left, comparator)
        elif op in ["<", ">="]:
            result = less_than(left, comparator)
        elif op in [">", "<="]:
            result = greater_than(left, comparator)
        else:
            raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
        # Invert the result for NotIn, NotEq, GtE, LtE
        if op in ["not in", "!=", ">=", "<="]:
            result = not_(result)
    return result

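Note the dispatch above: `>=`, `<=`, `!=`, and `not in` are computed as their complements and then inverted, so only `equal`, `less_than`, `greater_than`, and `in_` need dynamic lowerings. For static values (editor's sketch, assuming `not_` reduces to Python's `not`):

# Editor's sketch: `x >= y` is evaluated as not_(less_than(x, y)).
x, y = 4, 4
assert _compare_executor(x, [y], [">="]) == (not (x < y)) == True
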
# =============================================================================
# Set the AST decorator
# =============================================================================

# Set the DSL specific functions
executor.set_functions(
    is_dynamic_expression,
    _loop_execute_range_dynamic,
    _if_execute_dynamic,
    _while_execute_dynamic,
    _compare_executor,
    any_,
    all_,
)

@ -10,6 +10,7 @@
# is strictly prohibited.

from typing import List, Tuple
from types import NoneType
from cutlass._mlir import ir
from cutlass._mlir.dialects import scf, arith
from cutlass._mlir.extras import types as T
@ -145,6 +146,30 @@ class ScfGenerator:
            # Use the provided terminator generator
            block_term_op_builder[builder](region_result)
        else:
            # For the standard yield op, check the result
            for arg, result, name in zip(
                mix_iter_args,
                (
                    region_result
                    if isinstance(region_result, list)
                    else [region_result]
                ),
                mix_iter_arg_names,
            ):
                if isinstance(arg, NoneType) and not isinstance(
                    result, NoneType
                ):
                    raise DSLRuntimeError(
                        (
                            f"`{name}` is None prior to this `{op_type_name}`, "
                            f"and updating it to a non-None value inside this `{op_type_name}` is not supported."
                        ),
                        suggestion=(
                            f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
                            f"or mark this `{op_type_name}` with "
                            f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
                        ),
                    )
        # Normalize region_result
        region_result_list = ScfGenerator._normalize_region_result_to_list(
            region_result
@ -200,6 +225,7 @@ def _loop_execute_range_dynamic(
    mix_iter_arg_names: List[str] = [],
    unroll: int = -1,
    unroll_full: bool = False,
    pipelining: int = None,
):
    """
    Example: build an scf.for with optional unroll, using our universal helper.
@ -236,6 +262,18 @@ def _loop_execute_range_dynamic(
        unroll_attr = LoopUnroll(count=unroll)
    log().debug("Unroll attribute: %s", unroll_attr)

    pipelining_attr = None
    if pipelining is not None:
        if pipelining >= 0:
            pipelining_attr = ir.IntegerAttr.get(
                ir.IntegerType.get_signless(32), pipelining
            )
        else:
            raise DSLRuntimeError(
                f"Pipelining must be non-negative, got {pipelining}"
            )
    log().debug("Pipelining attribute: %s", pipelining_attr)

    log().debug(
        "Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
        start_,
@ -265,6 +303,9 @@ def _loop_execute_range_dynamic(
    if unroll_attr is not None:
        for_op.attributes["loop_annotation"] = unroll_attr

    if pipelining_attr is not None:
        for_op.attributes["cutlass.pipelining"] = pipelining_attr

    return for_op
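
A hypothetical call site (editor's sketch; the leading positional arguments and the loop body are assumed, since this hunk only shows the tail of the signature):

# Editor's sketch: requesting unroll-by-4 plus 2-stage pipelining.
for_op = _loop_execute_range_dynamic(
    body_fn, start, stop, step,   # assumed to exist at the call site
    unroll=4,
    pipelining=2,
)
# The resulting scf.for then carries both the `loop_annotation` and the
# `cutlass.pipelining` attributes set above.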

def for_body_builder(

@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.0.0
nvidia-cutlass-dsl==4.1.0.dev0

@ -73,6 +73,7 @@ class EVTFrontendBase:
        self.dag_ir = DAGIR(self.cc, self.element_compute)
        self.compute_cnt = 0
        self.layout_cnt = 0
        self.imm_cnt = 0

        self.pass_manager = EVTPassManager(
            self.dag_ir,
@ -107,6 +108,13 @@ class EVTFrontendBase:
        # Parse the input
        self.parse(*args, **kwargs)

        # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
        if (self.cc >= 90):
            if (self.dag_ir.out_degree("D") != 0):
                raise RuntimeError(
                    f"On SM90 or higher, D is expected to be an output node with 0 users to "
                    f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")

        # Run the passes
        self.pass_manager()
        # Set the epilogue type
@ -187,7 +195,8 @@ class EVTFrontendBase:
        except:
            raise ValueError(f"{type(value).__name__} cannot be converted to float.")

        name = f"imm_{value}".replace('.', '_')
        name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
        self.imm_cnt += 1
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": value, "is_constant": True}
        self.add_node(load_node)

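Why the counter suffix matters (editor's note): DAG node names must be unique, and two uses of the same literal previously both produced the same `imm_*` name. A sketch of the naming before and after:

# Editor's sketch: two occurrences of the immediate 0.5.
# Before: both map to "imm_0_5" and collide in the DAG IR.
name0 = f"imm_{0.5}_k{0}".replace('.', '_')   # 'imm_0_5_k0'
name1 = f"imm_{0.5}_k{1}".replace('.', '_')   # 'imm_0_5_k1'
assert name0 != name1
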
@ -42,7 +42,7 @@ from cutlass_library import DataType

import cutlass
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass.backend.epilogue import relu
from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
from cutlass.backend.library import FunctionalOp


@ -72,10 +72,17 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
            ast.Div: FunctionalOp.Divides,
            "maximum": FunctionalOp.Maximum,
            "minimum": FunctionalOp.Minimum,
            "identity": identity.binding_type,
            "relu": relu.binding_type,
            "tanh": tanh.binding_type,
            "sigmoid": sigmoid.binding_type,
            "silu": silu.binding_type,
            "hardswish": hardswish.binding_type,
            "gelu": gelu.binding_type,
            "multiply_add": FunctionalOp.MultiplyAdd,
            "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
            "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum)
            "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
            "exp": FunctionalOp.Exp
        }
        return mapping[op]

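With `exp` registered in the mapping, an EVT epilogue can now reference it by name. A sketch of an epilogue definition this frontend could parse (editor's illustration; the surrounding trace/compile machinery is assumed):

# Editor's sketch: epilogue using the newly mapped `exp`.
def example_epilogue(accum, alpha):
    D = exp(alpha * accum)
    return D
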
@ -38,7 +38,9 @@ import networkx as nx

from cutlass_library import DataType

from cutlass.backend.evt.ir.compute_nodes import ComputeNode
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.library import ActivationOp
from cutlass.backend.utils import device_cc


@ -59,6 +61,8 @@ class DAGIR:

        self.cc = cc

        self.identity_counter = 0

    #
    # IR manipulator
    #
@ -79,7 +83,21 @@ class DAGIR:
            raise SyntaxError(f"Variable '{src}' is undefined.")
        if not self.has_node(dst):
            raise SyntaxError(f"Variable '{dst}' is undefined.")
        self._graph.add_edge(src, dst, weight=weight)

        if self._graph.has_edge(src, dst):
            # DiGraph does not support multiple edges between the same pair
            # of nodes, so we insert an identity node as a workaround.
            identity_name = f"autogen_identity_{self.identity_counter}"
            self.identity_counter += 1
            compute_node = ComputeNode(
                name=identity_name, fn=ActivationOp.Identity,
                element_output=self.element_compute,
                element_compute=self.element_compute)
            self.add_node(compute_node)
            self.add_edge(src, identity_name, 0)
            self.add_edge(identity_name, dst, weight)
        else:
            self._graph.add_edge(src, dst, weight=weight)

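The workaround exists because `networkx.DiGraph` keeps at most one edge per (src, dst) pair; re-adding an edge merely updates its attributes (editor's sketch):

# Editor's sketch: a plain DiGraph cannot carry parallel edges.
import networkx as nx
g = nx.DiGraph()
g.add_edge("A", "B", weight=0)
g.add_edge("A", "B", weight=1)   # overwrites; does not add a second edge
assert g.number_of_edges() == 1 and g["A"]["B"]["weight"] == 1
# Hence the autogen identity node: A -> identity -> B keeps both operands.
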
    def remove_node(self, node: str):
        """

@ -51,15 +51,19 @@ class Tensor:
    """
    The Tensor abstracts the element data type, shape, and layout of a tensor.
    """
    def __init__(self, tensor=None, element=None, shape=None, layout_tag=None, is_constant=False) -> None:
    def __init__(self, tensor=None, element=None, shape=None, stride=None, layout_tag=None, is_constant=False) -> None:
        if element is not None and tensor is not None:
            raise Exception(f"Must not specify both element and tensor")
        elif shape is not None and tensor is not None:
            raise Exception(f"Must not specify both shape and tensor")
        elif layout_tag is not None and tensor is not None:
            raise Exception(f"Must not specify both layout_tag and tensor")
        elif (element is None or layout_tag is None or shape is None) and (tensor is None):
            raise Exception(f"Must specify one of (element, shape, layout) or (tensor)")
        elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None):
            raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
        elif stride is not None and tensor is not None:
            raise Exception(f"Must not specify both stride and tensor")
        elif stride is not None and layout_tag is not None:
            raise Exception(f"Must not specify layout_tag when stride is provided")

        if isinstance(tensor, Tensor):
            # Directly copy all the attributes
@ -70,10 +74,13 @@ class Tensor:
        else:
            self.element, layout_tag = get_datatype_and_layout(tensor)
            shape = get_tensor_shape(tensor)
        if layout_tag == LayoutType.RowMajor:
            self.layout = Layout(shape[::-1])
        elif layout_tag == LayoutType.ColumnMajor:
            self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
        if stride is not None:
            self.layout = Layout(shape[::-1], stride[::-1])
        else:
            if layout_tag == LayoutType.RowMajor:
                self.layout = Layout(shape[::-1])
            elif layout_tag == LayoutType.ColumnMajor:
                self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
        self.layout = canonicalization(self.layout)

        self.is_constant = is_constant

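How the explicit-stride path composes (editor's sketch): the row-major (C order) shape and stride are both reversed before being handed to the cute-style `Layout`, mirroring the RowMajor branch above.

# Editor's sketch: a 4x8 row-major tensor described with an explicit stride.
shape, stride = (4, 8), (8, 1)              # C order: last dim is fastest
reversed_args = (shape[::-1], stride[::-1])
assert reversed_args == ((8, 4), (1, 8))    # what Layout(...) receives
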
@ -77,11 +77,12 @@ class PassDAG2Tree(EVTPassBase):
            reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
        # Get the commonly reachable nodes
        common_items = set.intersection(*reachable_nodes)
        node_to_fuse = set.union(*reachable_nodes).difference(common_items)

        lca = None
        # If a common ancestor exists, find the lowest one
        if len(common_items) > 0:
            topo_order = self.dag_ir.nodes_topological_order()
            lca = None
            topo_idx = -1
            for item in common_items:
                if lca is None:
@ -91,53 +92,74 @@ class PassDAG2Tree(EVTPassBase):
                if topo_idx > topo_order.index(item):
                    lca = item
                    topo_idx = topo_order.index(item)
            # The lca is the output node of the DAG node
            # Get the nodes to be fused
            node_to_fuse = set.union(*reachable_nodes).difference(common_items)
            node_to_fuse.add(lca)
            # Get all the input nodes
            all_input_nodes = []
            all_output_nodes = []
            for node in node_to_fuse:
                all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
                all_output_nodes.append(set(self.dag_ir.get_users(node)))
            all_input_nodes = set.union(*all_input_nodes)
            all_output_nodes = set.union(*all_output_nodes)

            new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)

            # Create the subgraph
            subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
            subgraph = DAGIR(self.dag_ir.cc)
            for node in subgraph_.nodes:
                meta = deepcopy(self.dag_ir.get_node_meta(node))
                if node not in node_to_fuse:
                    meta.disabled = True
                subgraph.add_node(meta)
            for edge in subgraph_.edges:
                subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))


            # Create the fused node
            dag_node = TopoVisitorNode(
                name=f"dag_{lca}", subgraph=subgraph,
                output_node=self.dag_ir.get_node_meta(lca))
            self.dag_ir.add_node(dag_node)

            # Add input edges
            for idx, node in enumerate(all_input_nodes):
                self.dag_ir.add_edge(node, dag_node.name, weight=idx)

            # Replace all uses with the DAG node (only 1 output node)
            self.dag_ir.replace_all_uses_with(lca, dag_node.name)

            # Remove all fused nodes
            node_to_fuse.remove(lca)
            for node in node_to_fuse:
                self.dag_ir.remove_node(node)

        else:
            raise NotImplementedError("No LCA found. Consider SplitTreeVisitor.")
            # There is no common ancestor for all the parents, so we pack all the
            # reachable nodes into a single DAG node as a fallback. The lca should
            # be the input node of one of the output nodes with out_degree = 0.
            potential_output_nodes = []
            for node in node_to_fuse:
                if self.dag_ir.out_degree(node) == 0:
                    potential_output_nodes.append(node)
            if len(potential_output_nodes) == 0:
                raise RuntimeError("No output node with out degree = 0 found.")

            output_node = None
            if (self.dag_ir.cc >= 90):
                # For SM90, the lca should be the input node of D
                if (not self.dag_ir.has_node("D")):
                    raise RuntimeError("D is not a node in the DAG IR.")
                output_node = "D"
            else:
                output_node = potential_output_nodes[0]

            if (output_node is None):
                raise RuntimeError("No output node found.")
            lca = self.dag_ir.get_all_inputs(output_node)[0]
            node_to_fuse.remove(output_node)

            # The lca is the output node of the DAG node
            # Get the nodes to be fused
            node_to_fuse.add(lca)
            # Get all the input nodes
            all_input_nodes = []
            all_output_nodes = []
            for node in node_to_fuse:
                all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
                all_output_nodes.append(set(self.dag_ir.get_users(node)))
            all_input_nodes = set.union(*all_input_nodes)
            all_output_nodes = set.union(*all_output_nodes)

            new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)

            # Create the subgraph
            subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
            subgraph = DAGIR(self.dag_ir.cc)
            for node in subgraph_.nodes:
                meta = deepcopy(self.dag_ir.get_node_meta(node))
                if node not in node_to_fuse:
                    meta.disabled = True
                subgraph.add_node(meta)
            for edge in subgraph_.edges:
                subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))


            # Create the fused node
            dag_node = TopoVisitorNode(
                name=f"dag_{lca}", subgraph=subgraph,
                output_node=self.dag_ir.get_node_meta(lca))
            self.dag_ir.add_node(dag_node)

            # Add input edges
            for idx, node in enumerate(all_input_nodes):
                self.dag_ir.add_edge(node, dag_node.name, weight=idx)

            # Replace all uses with the DAG node (only 1 output node)
            self.dag_ir.replace_all_uses_with(lca, dag_node.name)

            # Remove all fused nodes
            node_to_fuse.remove(lca)
            for node in node_to_fuse:
                self.dag_ir.remove_node(node)

    def ensures(self) -> None:
        # Ensure that after the pass, the resulting DAG becomes a tree

@ -118,6 +118,7 @@ class FunctionalOp(enum.Enum):
    Multiplies = enum_auto()
    MultiplyAdd = enum_auto()
    Plus = enum_auto()
    Exp = enum_auto()


FunctionalOpTag = {
@ -130,6 +131,7 @@ FunctionalOpTag = {
    FunctionalOp.Multiplies: "cutlass::multiplies",
    FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
    FunctionalOp.Plus: "cutlass::plus",
    FunctionalOp.Exp: "cutlass::fast_exp_op",
}



@ -52,4 +52,5 @@ from cutlass.epilogue.evt_ops import (
    reshape,
    maximum,
    minimum,
    exp
)

@ -73,6 +73,12 @@ def minimum(x, y):
    elif is_torch_tensor(x):
        return torch.minimum(x, torch.tensor(y))

def exp(x):
    if is_numpy_tensor(x):
        return np.exp(x)
    elif is_torch_tensor(x):
        return torch.exp(x)
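
A quick check of the new host-side reference op (editor's sketch):

# Editor's sketch: `exp` dispatches on the tensor library of its argument.
import numpy as np
assert np.allclose(exp(np.array([0.0, 1.0])), np.array([1.0, np.e]))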


##############################################################################
# Layout manipulation nodes

@ -297,7 +297,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
    sm100_mma_data_type_general = [
        'gemm_f16_f16_f16_f16_f16',
        'gemm_f16_f16_f16_void_f16',
        'gemm_f16_f16_f32_f16_f16',
        #'gemm_f16_f16_f32_f16_f16',
        'tf32gemm_f32_f32_f32_f32_f32',
        'bf16gemm_f32_f32_f32_f32_f32',
    ]
@ -336,7 +336,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
        'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
        'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
        'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
        'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
        #'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
        'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
    ]

@ -547,7 +547,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode

    if dynamic_cluster:
        if mode == "functional_L0":
            runtime_cluster_shapes = [[1,1,1], [2,1,1], [2,2,1], [4,1,1], [4,4,1]]
            runtime_cluster_shapes = [[1,1,1], [2,2,1]]
        else:
            runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
        cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
