v4.1 release update v2. (#2481)
@@ -300,6 +300,21 @@ def if_executor(


class range:
    """
    A range-like object for dynamic loop iteration in the DSL.

    This class provides a range interface similar to Python's built-in range,
    but is designed to be preprocessed into constructs for dynamic
    loop execution.

    The class supports both single-argument (stop) and three-argument
    (start, stop, step) constructors with additional parameters for loop
    optimization:

    - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
    - unroll_full: Whether to fully unroll the loop
    - pipelining: Compiler generated pipeline configuration
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
        pass
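The new `range` object is what `cutlass.range(...)` resolves to after preprocessing. Below is a minimal sketch of how the added parameters are meant to be used inside a `cute.jit` function; the function name and body are placeholders, `range_constexpr` is assumed to live in the same namespace, and only the `cutlass.range` keywords documented above come from this commit.

import cutlass
import cutlass.cute as cute

@cute.jit
def body(n: cutlass.Int32):  # hypothetical function, for illustration only
    # Dynamic trip count, lowered to an IR loop; unroll=4 asks for 4x unrolling
    # (0 or 1 means no unrolling), unroll_full=True would unroll completely.
    for i in cutlass.range(0, n, 1, unroll=4):
        pass  # loop body elided

    # Compile-time-constant trip count: range_constexpr keeps the loop in Python
    # so it is fully expanded during preprocessing.
    for j in cutlass.range_constexpr(8):
        pass  # loop body elided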
@@ -460,7 +475,31 @@ def range_value_check(*args):
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
    """
    try:
        return tuple(arg.__index__() for arg in args)
        args = tuple(arg.__index__() for arg in args)

        # Compute range size and warn if it's too large
        start = 0
        end = 0
        step = 1
        if len(args) == 1:
            end = args[0]
        elif len(args) == 2:
            start = args[0]
            end = args[1]
        elif len(args) == 3:
            start = args[0]
            end = args[1]
            step = args[2]

        range_length = (abs(end - start) - 1) // abs(step) + 1
        if range_length >= 64:
            warnings.warn(
                f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
                category=UserWarning,
                stacklevel=2,
            )

        return (start, end, step)
    except:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
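The iteration-count check above is easy to verify by hand; the snippet below just replays the formula with plain Python ints to show when the 64-iteration warning fires (the numbers are illustrative, not from the commit).

def static_range_length(start: int, end: int, step: int = 1) -> int:
    # Same formula as range_value_check: number of iterations of range(start, end, step).
    return (abs(end - start) - 1) // abs(step) + 1

assert static_range_length(0, 8) == 8        # small static loop, no warning
assert static_range_length(0, 128, 2) == 64  # hits the >= 64 threshold -> UserWarning
                                             # suggesting cutlass.range(..., unroll_full=True)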
@@ -477,8 +516,8 @@ def range_perf_warning(filename, lineno, *args):
    if not has_dynamic_expr:
        warnings.warn_explicit(
            (
                "The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression."
                "If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`."
                "This loop is no longer unrolled and may cause performance regression. "
                "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
            ),
            category=UserWarning,
            filename=filename,

@@ -102,6 +102,8 @@ class ScopeManager:
        return cls([])

    def add_to_scope(self, name: str) -> None:
        if name == "_":
            return
        self.scopes[-1].add(name)

    def get_active_symbols(self) -> List[Set[str]]:
@@ -361,13 +363,13 @@ class DSLPreprocessor(ast.NodeTransformer):
            isinstance(func, ast.Name)
            and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS
        ):
            return func.id, True
            return func.id, True, len(iter_node.keywords) != 0
        if (
            isinstance(func, ast.Attribute)
            and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS
        ):
            return func.attr, False
        return None, None
            return func.attr, False, len(iter_node.keywords) != 0
        return None, None, None
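The extended `_get_range_kind` contract now reports three things to callers: the range kind, whether the bare builtin-style name was used, and whether keyword arguments were present. The snippet below mimics that classification on small AST examples; it is a standalone sketch (the real `SUPPORTED_FOR_RANGE_STATEMENTS` set is not shown in this diff, so a plausible subset is assumed), not the preprocessor's actual code path.

import ast

SUPPORTED = {"range", "range_constexpr"}  # assumed subset of SUPPORTED_FOR_RANGE_STATEMENTS

def classify(src: str):
    # Mirror the preprocessor's (kind, is_builtin_range, has_keyword) triple.
    iter_node = ast.parse(src, mode="eval").body
    func = iter_node.func
    if isinstance(func, ast.Name) and func.id in SUPPORTED:
        return func.id, True, len(iter_node.keywords) != 0
    if isinstance(func, ast.Attribute) and func.attr in SUPPORTED:
        return func.attr, False, len(iter_node.keywords) != 0
    return None, None, None

print(classify("range(n)"))                    # ('range', True, False)  -> may trigger the perf warning
print(classify("range(n, unroll_full=True)"))  # ('range', True, True)   -> keyword given, no warning
print(classify("cutlass.range_constexpr(8)"))  # ('range_constexpr', False, False)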
    def transform(self, original_function, exec_globals):
        """
@@ -378,6 +380,7 @@ class DSLPreprocessor(ast.NodeTransformer):
        transformed_tree = self.transform_function(
            original_function.__name__, original_function
        )
        self.function_globals = None
        unified_tree = ast.Module(body=transformed_tree, type_ignores=[])
        unified_tree = ast.fix_missing_locations(unified_tree)

@@ -731,7 +734,7 @@ class DSLPreprocessor(ast.NodeTransformer):
            self.scope_manager.add_to_scope(node.target.id)

        # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop.
        range_kind, is_builtin_range = self._get_range_kind(node.iter)
        range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter)
        if range_kind == "range_constexpr" or range_kind == None:
            self.generic_visit(node)
            if range_kind == "range_constexpr":
@@ -752,7 +755,7 @@ class DSLPreprocessor(ast.NodeTransformer):
        warnings.simplefilter("default", DeprecationWarning) # reset filter

        warning_call = None
        if range_kind == "range" and is_builtin_range:
        if range_kind == "range" and is_builtin_range and not has_keyword:
            # Warn about possible performance regression due to behavior change
            warning_call = ast.Expr(
                ast.Call(
@@ -1109,6 +1112,12 @@ class DSLPreprocessor(ast.NodeTransformer):
        self.generic_visit(node)
        return node

    def visit_Name(self, node):
        self.generic_visit(node)
        if node.id == "_" and isinstance(node.ctx, ast.Load):
            raise DSLAstPreprocessorError("Read '_' is not allowed")
        return node

    def check_decorator(self, node: ast.AST) -> bool:
        """
        Check if the function has the correct decorator for preprocessing.

@@ -19,7 +19,9 @@ from typing import Sequence, Optional, Tuple
import os
import sys
import inspect
import argparse
from .common import DSLRuntimeError
from .utils.logger import log

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
@@ -182,7 +184,67 @@ class Compiler:
        return self.jit(module, opt_level, shared_libs)


class CompileOptions:
    def __init__(self, options: str = ""):
        """
        This class encapsulates all compilation options relevant to function compilation.
        It provides a convenient way to manage and pass compilation options,
        particularly for controlling compilation settings.
        By centralizing these options, it ensures consistent and flexible configuration of
        compilation parameters such as optimization level, debugging control, etc.

        :param options: The options for the function. Will be parsed by argparse.
        :type options: str
        """
        if not isinstance(options, str):
            raise DSLRuntimeError(
                f"Invalid compilation `options`: {options}, it should be a string"
            )
        self._parser = argparse.ArgumentParser()
        self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
        self._parser.add_argument(
            "--enable-device-assertions", action="store_true", default=False
        )
        try:
            self._options = self._parser.parse_args(options.split())
        except SystemExit as e:
            # catch argparse error and raise as DSLRuntimeError
            raise DSLRuntimeError(
                f"Invalid compile options: '{options}'. Please check the option values and format."
            )
        log().info("`cute.compile` CompileOptions: options=" + options)

    def to_str(self):
        """
        Generate a string representation of all compilation options
        which will be used in pipeline options.
        """
        option_strings = []
        for key, value in vars(self._options).items():
            hyphen_key = key.replace("_", "-")
            if isinstance(value, bool):
                formatted_value = "true" if value else "false"
            else:
                formatted_value = str(value)
            option_strings.append(f"{hyphen_key}={formatted_value}")

        return " ".join(option_strings)


def compile(func, *args, **kwargs):
    """
    This function is used to compile a `cute.jit` decorated function.
    It will process the compile options and input parameters, do explicit compilation and return the jit executor.

    :param func: The function to compile. It can be a regular function, a method or a class instance.
    :param args: The arguments to pass to the function.
    :param kwargs: The keyword arguments to pass to the function. It can contain `options` like
        `opt_level` to control the compilation flags.

    :return: The jit executor.

    :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
    """
    if func is None:
        raise DSLRuntimeError("Function is not set or invalid.")

@@ -217,5 +279,8 @@ def compile(func, *args, **kwargs):
    if not hasattr(func, "_dsl_object"):
        raise DSLRuntimeError("Function is not decorated with jit decorator.")

    # process compile options, extract the options and remove them from the kwargs
    options = kwargs.pop("options", "")
    func._dsl_object.compile_options = CompileOptions(options)
    fcn_ptr = func._dsl_object._preprocess_and_execute(func)
    return func._dsl_object._func(fcn_ptr, *args, **kwargs)
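Based on the argparse setup above, a caller passes the new options as a single string; the sketch below shows the expected round trip. The kernel and its arguments are placeholders, only the `options` keyword and the two flags (`--opt-level`, `--enable-device-assertions`) come from this diff.

import cutlass.cute as cute

# Explicit compilation of a cute.jit-decorated function with compile options.
# `my_kernel`, `a`, and `b` are hypothetical; `options` is parsed by CompileOptions.
compiled = cute.compile(my_kernel, a, b, options="--opt-level 2 --enable-device-assertions")

# Internally, CompileOptions.to_str() renders the parsed namespace back into
# hyphenated key=value pairs for the pipeline, e.g.:
#   "opt-level=2 enable-device-assertions=true"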
@@ -38,6 +38,7 @@ import warnings

from . import typing as t
from .env_manager import EnvironmentVarManager
from .compiler import CompileOptions

# =============================================================================
# CUDA Python
@@ -232,6 +233,50 @@ def new_from_mlir_values(obj, values):
    return obj


class DSLCallable:
    """
    Wrapper class for a callable object used within the DSL.

    DSLCallable is designed to wrap a function and provide additional
    introspection utilities such as retrieving the argument specification
    and signature. It ensures that the wrapped function can only be called
    once, after which the reference to the function is cleared to prevent
    further invocations. This is useful in scenarios where a function should
    only be executed a single time within the DSL's execution model.

    Attributes:
        func (callable): The function to be wrapped and managed.

    Methods:
        __call__(*args, **kwargs): Calls the wrapped function and clears it.
        get_arg_spec(): Returns the argument specification of the function.
        get_signature(): Returns the signature of the function.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        ret = self.__func__(*args, **kwargs)
        self.func = None
        return ret

    @property
    def __func__(self):
        assert self.func is not None, "DSLCallable is already called"
        return self.func

    @property
    def __name__(self):
        return self.__func__.__name__

    def get_arg_spec(self):
        return inspect.getfullargspec(self.__func__)

    def get_signature(self):
        return inspect.signature(self.__func__)

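A small, self-contained sketch of the call-once behaviour documented above; the wrapped function is a stand-in, and the abridged class is re-typed from the diff just so the example runs without the rest of the DSL.

import inspect

class DSLCallable:  # abridged restatement, for illustration only
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        ret = self.__func__(*args, **kwargs)
        self.func = None          # drop the reference so a second call fails fast
        return ret

    @property
    def __func__(self):
        assert self.func is not None, "DSLCallable is already called"
        return self.func

    def get_signature(self):
        return inspect.signature(self.__func__)

def add(x, y):  # placeholder function
    return x + y

wrapped = DSLCallable(add)
print(wrapped.get_signature())  # (x, y)
print(wrapped(1, 2))            # 3
# Calling wrapped(3, 4) now raises AssertionError: the callable is single-use.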
class BaseDSL:
|
||||
gpu_module = None
|
||||
|
||||
@ -306,6 +351,8 @@ class BaseDSL:
|
||||
self.kernel_symbols = []
|
||||
# used to generate unique name for gpu.launch
|
||||
self.launch_inner_count = 0
|
||||
# initialize default compile options
|
||||
self.compile_options = CompileOptions()
|
||||
|
||||
if preprocess:
|
||||
self.preprocessor = DSLPreprocessor()
|
||||
@ -392,26 +439,24 @@ class BaseDSL:
|
||||
if hasattr(func, "_transformed_ast"):
|
||||
# If the function ptr is already materialized, use the existing one
|
||||
func._dsl_object.frame = func._decorator_frame
|
||||
|
||||
if func._transformed_ast is None:
|
||||
func._transformed_ast = func._dsl_object.run_preprocessor(func)
|
||||
if func._transformed_ast is None:
|
||||
del func._decorator_frame
|
||||
del func._transformed_ast
|
||||
func._dsl_object.frame = None
|
||||
return func
|
||||
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func, func._transformed_ast)
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func)
|
||||
# If the function is decorated, de-decorate it
|
||||
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
|
||||
return fcn_ptr
|
||||
func._dsl_object.frame = None
|
||||
return DSLCallable(fcn_ptr)
|
||||
return func
|
||||
|
||||
def jit_runner(self, frame, executor, *dargs, **dkwargs):
|
||||
def jit_runner(self, executor, frame, *dargs, **dkwargs):
|
||||
"""
|
||||
Decorator to mark a function for JIT compilation.
|
||||
"""
|
||||
# Set the frame, that can be used AST preprocessor
|
||||
self.frame = frame
|
||||
log().info("jit_runner")
|
||||
|
||||
def jit_runner_decorator(func):
|
||||
@ -444,7 +489,7 @@ class BaseDSL:
|
||||
frame = inspect.currentframe().f_back
|
||||
# Instantiate the DSL Class
|
||||
main_dsl = cls._get_dsl()
|
||||
return main_dsl.jit_runner(frame, main_dsl._func, *dargs, **dkwargs)
|
||||
return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
|
||||
|
||||
@classmethod
|
||||
def kernel(cls, *dargs, **dkwargs):
|
||||
@ -454,7 +499,7 @@ class BaseDSL:
|
||||
frame = inspect.currentframe().f_back
|
||||
# Instantiate the DSL Class
|
||||
main_dsl = cls._get_dsl()
|
||||
return main_dsl.jit_runner(frame, main_dsl._kernel_helper, *dargs, **dkwargs)
|
||||
return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _kernel_helper(self, func, *args, **kwargs):
|
||||
@ -627,6 +672,12 @@ class BaseDSL:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_module_globals(self):
|
||||
"""
|
||||
Get the module's globals.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _get_globals(self):
|
||||
"""
|
||||
Combines global and local variables from the current context and the
|
||||
@ -639,7 +690,11 @@ class BaseDSL:
|
||||
AST preprocessor generates a new python code, so the resulting globals
|
||||
dictionary is used to execute the python code.
|
||||
"""
|
||||
pass
|
||||
all_globals = self._get_module_globals().copy()
|
||||
if self.frame:
|
||||
all_globals.update(self.frame.f_globals)
|
||||
all_globals.update(self.frame.f_locals)
|
||||
return all_globals
|
||||
|
||||
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
|
||||
return isinstance(
|
||||
@ -881,20 +936,15 @@ class BaseDSL:
|
||||
Get python location information and generate MLIR location
|
||||
"""
|
||||
|
||||
frame = self.frame
|
||||
if frame is None:
|
||||
print("Frame is None")
|
||||
if self.frame is None:
|
||||
log().debug("Frame is None")
|
||||
return None
|
||||
|
||||
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
|
||||
file_loc = ir.Location.file(
|
||||
self.frame.f_code.co_filename, self.frame.f_lineno, 0
|
||||
)
|
||||
|
||||
def print_all_frames():
|
||||
for i, frame in enumerate(inspect.stack()):
|
||||
print(
|
||||
f"Frame {i}: {frame.function} in {frame.filename}, line {frame.lineno}"
|
||||
)
|
||||
|
||||
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
|
||||
loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc)
|
||||
return loc
|
||||
|
||||
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
|
||||
@ -992,6 +1042,8 @@ class BaseDSL:
|
||||
for attr, value in self.envar.__dict__.items():
|
||||
if value is not None:
|
||||
s.write(str(value).encode())
|
||||
# Add compile options to the hash
|
||||
s.write(self.compile_options.to_str().encode())
|
||||
module_hash = self.get_version().copy()
|
||||
module_hash.update(s.getvalue())
|
||||
module_hash = module_hash.hexdigest()
|
||||
@ -1145,6 +1197,8 @@ class BaseDSL:
|
||||
self.launch_inner_count = 0
|
||||
# reset num_kernels to 0 for next compilation.
|
||||
self.num_kernels = 0
|
||||
# reset the compile options after the compilation is done.
|
||||
self.compile_options = CompileOptions()
|
||||
|
||||
def generate_mlir(
|
||||
self,
|
||||
@ -1226,9 +1280,11 @@ class BaseDSL:
|
||||
return transformed_ast
|
||||
return None
|
||||
|
||||
def get_function_ptr(self, original_function, transformed_ast):
|
||||
def get_function_ptr(self, original_function):
|
||||
file_name = inspect.getsourcefile(original_function)
|
||||
code_object = compile(transformed_ast, filename=file_name, mode="exec")
|
||||
code_object = compile(
|
||||
original_function._transformed_ast, filename=file_name, mode="exec"
|
||||
)
|
||||
return self.preprocessor.exec(
|
||||
original_function.__name__,
|
||||
original_function,
|
||||
@ -1236,10 +1292,6 @@ class BaseDSL:
|
||||
self._get_globals(),
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _get_function_signature(self, func):
|
||||
return inspect.signature(func)
|
||||
|
||||
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
|
||||
"""
|
||||
Binds provided arguments to a function's signature and applies default values.
|
||||
@ -1260,12 +1312,11 @@ class BaseDSL:
|
||||
)
|
||||
return bound_args
|
||||
|
||||
def _canonicalize_args(self, *args, **kwargs):
|
||||
def _canonicalize_args(self, sig, *args, **kwargs):
|
||||
"""
|
||||
Canonicalize the input arguments so that returned args only contain
|
||||
positional arguments and kwargs only contain keyword arguments.
|
||||
"""
|
||||
sig = self._get_function_signature(self.funcBody)
|
||||
function_name = self.funcBody.__name__
|
||||
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
|
||||
canonicalized_args = bound_args.args
|
||||
@ -1276,8 +1327,11 @@ class BaseDSL:
|
||||
if not self.funcBody:
|
||||
raise DSLRuntimeError("Function body is not set.")
|
||||
|
||||
# Pass the actual function object to _get_function_signature.
|
||||
sig = self._get_function_signature(self.funcBody)
|
||||
# Pass the actual function object to inspect.signature to get the signature.
|
||||
if isinstance(self.funcBody, DSLCallable):
|
||||
sig = self.funcBody.get_signature()
|
||||
else:
|
||||
sig = inspect.signature(self.funcBody)
|
||||
function_name = self.funcBody.__name__
|
||||
|
||||
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
|
||||
@ -1292,6 +1346,8 @@ class BaseDSL:
|
||||
f"Missing required argument in `{function_name}`: '{param.name}'"
|
||||
)
|
||||
|
||||
return sig
|
||||
|
||||
def _func(self, funcBody, *args, **kwargs):
|
||||
"""Decorator for MLIR functions.
|
||||
It cuts the boilerplate code, does the following:
|
||||
@ -1324,13 +1380,16 @@ class BaseDSL:
|
||||
self.print_warning("Cache is disabled as user wants to compile only.")
|
||||
|
||||
# Check the number of arguments
|
||||
self._check_arg_count(*args, **kwargs)
|
||||
sig = self._check_arg_count(*args, **kwargs)
|
||||
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
|
||||
# Canonicalize the input arguments
|
||||
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
|
||||
*args, **kwargs
|
||||
sig, *args, **kwargs
|
||||
)
|
||||
|
||||
# Simple name mangling
|
||||
@ -1528,7 +1587,10 @@ class BaseDSL:
|
||||
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
|
||||
|
||||
kernel_name = funcBody.__name__
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
self.funcBody = funcBody
|
||||
|
||||
# Give each kernel a unique name. (The same kernel may be
|
||||
@ -1568,11 +1630,11 @@ class BaseDSL:
|
||||
), "kernelGenHelper should be explicitly specified!"
|
||||
|
||||
# check arguments
|
||||
self._check_arg_count(*args, **kwargs)
|
||||
sig = self._check_arg_count(*args, **kwargs)
|
||||
|
||||
# Canonicalize the input arguments
|
||||
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
|
||||
*args, **kwargs
|
||||
sig, *args, **kwargs
|
||||
)
|
||||
|
||||
kernel_operands, kernel_types, kernel_arg_attrs = (
|
||||
|
||||
@@ -527,7 +527,16 @@ class IntegerMeta(NumericMeta):
        return 2**cls.width - 1

    def recast_width(cls, width):
        return eval(f"Int{width}")
        type_map = {
            8: Int8,
            16: Int16,
            32: Int32,
            64: Int64,
            128: Int128,
        }
        if width not in type_map:
            raise TypeError(f"Unsupported width: {width}")
        return type_map[width]


class FloatMeta(NumericMeta):
@@ -603,7 +612,14 @@ class FloatMeta(NumericMeta):
        return cls._mantissa_width

    def recast_width(cls, width):
        return eval(f"Float{width}")
        type_map = {
            16: Float16,
            32: Float32,
            64: Float64,
        }
        if width not in type_map:
            raise TypeError(f"Unsupported width: {width}")
        return type_map[width]
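The lookup-table rewrite above replaces the old eval-based dispatch while preserving the same width-to-type mapping. A minimal sketch, assuming the usual `cutlass.Int*` / `cutlass.Float*` aliases are importable and that `recast_width` is called on the type itself:

import cutlass

# recast_width maps a bit width onto the corresponding numeric type, and now
# rejects unsupported widths with a TypeError instead of a NameError from eval().
assert cutlass.Int32.recast_width(64) is cutlass.Int64
assert cutlass.Float32.recast_width(16) is cutlass.Float16

try:
    cutlass.Int32.recast_width(48)   # not in the table
except TypeError as exc:
    print(exc)                       # "Unsupported width: 48"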
def _arith_signless_to_int(a, target_type):
|
||||
|
||||
@ -118,6 +118,9 @@ from .core import (
|
||||
make_tiled_copy,
|
||||
make_tiled_copy_S,
|
||||
make_tiled_copy_D,
|
||||
make_tiled_copy_A,
|
||||
make_tiled_copy_B,
|
||||
make_tiled_copy_C,
|
||||
make_tiled_copy_C_atom,
|
||||
basic_copy,
|
||||
basic_copy_if,
|
||||
|
||||
@ -90,6 +90,7 @@ __all__ = [
|
||||
#
|
||||
"alloc_smem",
|
||||
"get_dyn_smem",
|
||||
"get_dyn_smem_size",
|
||||
#
|
||||
# tmem.py
|
||||
#
|
||||
|
||||
@@ -26,7 +26,19 @@ from cutlass._mlir.dialects.nvvm import (
    RoundingModeKind,
)

from ..typing import Int, Boolean, Int32, Float32, Numeric, as_numeric
from ..typing import (
    Int,
    Boolean,
    Int16,
    Uint16,
    Int32,
    Uint32,
    Int64,
    Float32,
    BFloat16,
    Numeric,
    as_numeric,
)

WARP_SIZE = 32
FULL_MASK = 0xFFFFFFFF
@@ -190,19 +202,97 @@ def shuffle_sync_op(
    """
    if not isinstance(value, Numeric):
        value = as_numeric(value)
    return type(value)(
        nvvm.shfl_sync(
            type(value).mlir_type,
    if value.width > 64:
        raise ValueError("shuffle_sync only supports values up to 64 bits")

    orig_type = type(value)
    if value.width < 32:
        if value.dtype.is_float:
            value = value.to(Float32)
        else:
            if value.signed:
                value = value.to(Int32)
            else:
                value = value.to(Uint32)
        return orig_type(
            nvvm.shfl_sync(
                type(value).mlir_type,
                Int32(mask).ir_value(loc=loc, ip=ip),
                value.ir_value(loc=loc, ip=ip),
                Int32(offset).ir_value(loc=loc, ip=ip),
                Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
                kind,
                loc=loc,
                ip=ip,
            )
        )
    elif value.width == 32:
        return orig_type(
            nvvm.shfl_sync(
                type(value).mlir_type,
                Int32(mask).ir_value(loc=loc, ip=ip),
                value.ir_value(loc=loc, ip=ip),
                Int32(offset).ir_value(loc=loc, ip=ip),
                Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
                kind,
                loc=loc,
                ip=ip,
            )
        )
    else:
        if value.width != 64:
            raise ValueError(
                "shuffle_sync only supports 64 bits values when the bit width is larger than 32"
            )
        value = llvm.bitcast(
            T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip
        )
        # extract low 32 bits
        low_32_bits = llvm.trunc(
            T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
        )
        # extract high 32 bits
        high_32_bits = llvm.lshr(
            value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
        )
        high_32_bits = llvm.trunc(
            T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
        )

        low_32_bits_shfl = nvvm.shfl_sync(
            T.i32(),
            Int32(mask).ir_value(loc=loc, ip=ip),
            value.ir_value(loc=loc, ip=ip),
            low_32_bits,
            Int32(offset).ir_value(loc=loc, ip=ip),
            Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
            kind,
            loc=loc,
            ip=ip,
        )
        high_32_bits_shfl = nvvm.shfl_sync(
            T.i32(),
            Int32(mask).ir_value(loc=loc, ip=ip),
            high_32_bits,
            Int32(offset).ir_value(loc=loc, ip=ip),
            Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
            kind,
            loc=loc,
            ip=ip,
        )
        )

        # combine low and high 32 bits
        low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip)
        high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip)
        shlf_res = llvm.shl(
            high_64_bit,
            Int64(32).ir_value(loc=loc, ip=ip),
            llvm.IntegerOverflowFlags.none,
            loc=loc,
            ip=ip,
        )
        shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip)
        shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip)
        return orig_type(shlf_res)


shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx)
shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up)
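The new 64-bit path above shuffles the two 32-bit halves separately and stitches them back together. The pure-Python sketch below replays that bit manipulation on plain integers so the recombination step is easy to check; it illustrates the arithmetic only, not the MLIR lowering.

def split_shuffle_combine(value: int, shuffle) -> int:
    """Shuffle a 64-bit integer by shuffling its 32-bit halves, as in shuffle_sync_op."""
    low = value & 0xFFFFFFFF             # llvm.trunc to i32
    high = (value >> 32) & 0xFFFFFFFF    # llvm.lshr then trunc
    low_shfl = shuffle(low)              # nvvm.shfl_sync on the low half
    high_shfl = shuffle(high)            # nvvm.shfl_sync on the high half
    return (high_shfl << 32) | low_shfl  # zext + shl + or, then bitcast back

# With an identity "shuffle" the value must round-trip unchanged.
x = 0x0123456789ABCDEF
assert split_shuffle_combine(x, lambda v: v) == x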
@@ -94,3 +94,15 @@ def get_dyn_smem(
        alignment,
    )
    return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip)


@dsl_user_op
def get_dyn_smem_size(*, loc=None, ip=None) -> int:
    """
    Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time.
    This can be used for bounds checking during shared memory allocation.

    :return: The size of dynamic shared memory in bytes
    :rtype: int
    """
    return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip)
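As the docstring suggests, the new query pairs naturally with dynamic shared-memory allocation for bounds checking. A hedged sketch of in-kernel usage follows; the `cute.arch` namespace, the `SharedStorage` struct, and the alignment value are assumptions for illustration, only `get_dyn_smem` and `get_dyn_smem_size` come from this diff and the export list above.

# Inside a device function (illustrative, not a complete kernel):
smem_size = cute.arch.get_dyn_smem_size()   # bytes provided at kernel launch
# `SharedStorage` is a hypothetical @cute.struct; __sizeof__() gives its static footprint.
# Compare SharedStorage.__sizeof__() against smem_size before carving the buffer up.
storage = cute.arch.get_dyn_smem(SharedStorage, alignment=16)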
@ -31,6 +31,7 @@ from typing import (
|
||||
Optional,
|
||||
)
|
||||
from enum import Enum, auto
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
const,
|
||||
@ -517,10 +518,14 @@ class ScaledBasis:
|
||||
sb3 = ScaledBasis(4, [0, 1]) # 4 * E([0, 1])
|
||||
|
||||
# Scaled basis elements are commonly used in layout strides
|
||||
layout = make_layout((4, 8), stride=(ScaledBasis(1, 0), ScaledBasis(1, 1)))
|
||||
layout = make_layout((4, 8), stride=(ScaledBasis(2, 0), ScaledBasis(1, 1)))
|
||||
|
||||
# This creates a layout with strides (1@0, 1@1) representing
|
||||
# This creates a layout with strides (2@0, 1@1) representing
|
||||
# a coordinate system where each dimension has its own basis
|
||||
|
||||
# Example: Mapping coordinates to indices using the layout
|
||||
coord = (2, 3)
|
||||
idx = crd2idx(coord, layout) # Maps (2, 3) to (4, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, value, mode) -> None:
|
||||
@ -712,8 +717,9 @@ class Swizzle(ir.Value):
|
||||
|
||||
e.g. Given
|
||||
0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
|
||||
|
||||
the result is
|
||||
0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
|
||||
0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ `xor` YY
|
||||
|
||||
"""
|
||||
|
||||
@ -897,7 +903,7 @@ class _Layout(Layout):
|
||||
|
||||
@ir.register_value_caster(_cute_ir.ComposedLayoutType.get_static_typeid(), replace=True)
|
||||
class ComposedLayout(ir.Value):
|
||||
"""ComposedLayout represents the functional composition of layouts in CuTe.
|
||||
r"""ComposedLayout represents the functional composition of layouts in CuTe.
|
||||
|
||||
A ComposedLayout is formed by the composition of three components:
|
||||
inner o offset o outer, where:
|
||||
@ -907,7 +913,10 @@ class ComposedLayout(ir.Value):
|
||||
- outer: The outer layout that is applied first
|
||||
|
||||
ComposedLayout implements the functional composition operation where:
|
||||
R(c) := (inner o offset o outer)(c) := inner(offset + outer(c))
|
||||
|
||||
.. math::
|
||||
|
||||
R(c) := (inner \\circ offset \\circ outer)(c) := inner(offset + outer(c))
|
||||
|
||||
This composition allows for complex transformations of coordinates and indices,
|
||||
enabling operations like tiling, partitioning, and reshaping of data.
|
||||
@ -1670,7 +1679,10 @@ def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None):
|
||||
:type ip: insertion pointer, optional
|
||||
:raises NotImplementedError: If the tensor type doesn't support trivial dereferencing
|
||||
|
||||
Example output:
|
||||
**Example output:**
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data=
|
||||
[[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042],
|
||||
[-0.8462, 0.9871, 0.4389, 0.7298, 0.6948],
|
||||
@ -1973,7 +1985,8 @@ def find_if(
|
||||
:return: Index if found at top level, tuple of indices showing nested position, or None if not found
|
||||
:rtype: Union[int, Tuple[int, ...], None]
|
||||
|
||||
Examples:
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Find the first position of x in t
|
||||
@@ -2186,6 +2199,23 @@ def is_congruent(
) -> bool:
    """
    Returns whether a is congruent to b.

    Congruence is an equivalence relation between hierarchical structures.

    Two objects are congruent if:
    * They have the same rank, AND
    * They are both non-tuple values, OR
    * They are both tuples AND all corresponding elements are congruent.

    Congruence requires type matching at each level -- scalar values match with
    scalar values, and tuples match with tuples of the same rank.

    :param a: First object to compare
    :type a: Union[XTuple, Layout, ComposedLayout, Tensor]
    :param b: Second object to compare
    :type b: Union[XTuple, Layout, ComposedLayout, Tensor]
    :return: True if a and b are congruent, False otherwise
    :rtype: bool
    """
    if isinstance(a, (Layout, ComposedLayout, Tensor)):
        a = a.shape
@@ -2204,6 +2234,22 @@ def is_weakly_congruent(
) -> bool:
    """
    Returns whether a is weakly congruent to b.

    Weak congruence is a partial order on hierarchical structures.

    Object X is weakly congruent to object Y if:
    * X is a non-tuple value, OR
    * X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent.

    Weak congruence allows scalar values to match with tuples, making it useful
    for determining whether an object has a hierarchical structure "up to" another.

    :param a: First object to compare
    :type a: Union[XTuple, Layout, ComposedLayout, Tensor]
    :param b: Second object to compare
    :type b: Union[XTuple, Layout, ComposedLayout, Tensor]
    :return: True if a and b are weakly congruent, False otherwise
    :rtype: bool
    """
    if isinstance(a, (Layout, ComposedLayout, Tensor)):
        a = a.shape
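The two docstrings above pin the rules down precisely, so they can be replayed on plain nested tuples. The helpers below are a standalone restatement of those rules for illustration, not the CuTe implementation (which also accepts Layout, ComposedLayout, and Tensor by taking their shapes).

def is_congruent(a, b) -> bool:
    # Same profile (tuple structure) on both sides.
    if not isinstance(a, tuple) and not isinstance(b, tuple):
        return True
    if isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
        return all(is_congruent(x, y) for x, y in zip(a, b))
    return False

def is_weakly_congruent(a, b) -> bool:
    # `a` may stop early: a scalar in `a` matches any subtree of `b`.
    if not isinstance(a, tuple):
        return True
    if isinstance(b, tuple) and len(a) == len(b):
        return all(is_weakly_congruent(x, y) for x, y in zip(a, b))
    return False

assert is_congruent((1, (2, 3)), (4, (5, 6)))        # same nesting profile
assert not is_congruent((1, 2), (1, (2, 3)))         # scalar vs tuple mismatch
assert is_weakly_congruent((1, 2), ((7, 8), (9,)))   # scalars match whole subtrees
assert not is_weakly_congruent(((1,), 2), (3, 4))    # a tuple in `a` needs a tuple in `b`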
@ -2261,8 +2307,11 @@ def get(input, mode: List[int], *, loc=None, ip=None):
|
||||
|
||||
**Examples:**
|
||||
|
||||
For a layout like ((4,8),2):((16,1),8), get with mode=[0,1] would extract
|
||||
the element 8 from the shape component.
|
||||
.. code-block:: python
|
||||
|
||||
layout = make_layout(((4, 8), (16, 1), 8), stride=((1, 4), (32, 0), 512))
|
||||
sub_layout = get(layout, mode=[0, 1]) # 8:4
|
||||
sub_layout = get(layout, mode=[1]) # (16, 1):(32, 0)
|
||||
"""
|
||||
# Empty mode returns input and terminates the recursive call
|
||||
if not mode:
|
||||
@ -5065,6 +5114,11 @@ def make_layout_tv(
|
||||
* 2 elements per thread
|
||||
"""
|
||||
|
||||
if not isinstance(thr_layout, Layout):
|
||||
raise TypeError(f"expected a Layout for thr_layout, but got {type(thr_layout)}")
|
||||
if not isinstance(val_layout, Layout):
|
||||
raise TypeError(f"expected a Layout for val_layout, but got {type(val_layout)}")
|
||||
|
||||
# Take the raked_products to compute the Layout_MN
|
||||
# (M,N) -> (thr_idx, val_idx)
|
||||
layout_mn = raked_product(thr_layout, val_layout, loc=loc, ip=ip)
|
||||
@ -5081,8 +5135,52 @@ def make_layout_tv(
|
||||
return (tiler_mn, layout_tv)
|
||||
|
||||
|
||||
def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
|
||||
if type(tiler_mn) is tuple:
|
||||
tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip)
|
||||
|
||||
assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance(
|
||||
tiler_mn.type
|
||||
), f"tiler_mn must be a Tile, but got {type(tiler_mn)}"
|
||||
assert is_static(layout_tv.type) and is_static(
|
||||
tiler_mn.type
|
||||
), "layout tv and tiler mn must be static"
|
||||
tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get(
|
||||
atom.type, layout_tv.type, tiler_mn.type
|
||||
)
|
||||
|
||||
val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip)
|
||||
# Instead of modifying atom which might have been provided by the user, create a brand new
|
||||
# trait instance and replace the Atom ir.Value with the tiled one
|
||||
trait = new_from_mlir_values(atom._trait, [val])
|
||||
return TiledCopy(atom.op, trait)
|
||||
|
||||
|
||||
@deprecated("Use make_tiled_copy_tv instead")
|
||||
def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
|
||||
"""Create a tiled type given a TV partitioner and tiler.
|
||||
|
||||
:param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc.
|
||||
:type atom: CopyAtom
|
||||
:param layout_tv: Thread-value layout
|
||||
:type layout_tv: Layout
|
||||
:param tiler_mn: Tile size
|
||||
:type tiler_mn: Tiler
|
||||
:param loc: Source location for MLIR, defaults to None
|
||||
:type loc: Optional[Location], optional
|
||||
:param ip: Insertion point, defaults to None
|
||||
:type ip: Optional[InsertionPoint], optional
|
||||
|
||||
:return: A tiled copy for the partitioner
|
||||
:rtype: TiledCopy
|
||||
"""
|
||||
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> TiledCopy:
|
||||
def make_tiled_copy_tv(
|
||||
atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None
|
||||
) -> TiledCopy:
|
||||
"""Create a tiled copy given separate thread and value layouts.
|
||||
|
||||
A TV partitioner is inferred based on the input layouts. The input thread layout
|
||||
@ -5105,30 +5203,17 @@ def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> Ti
|
||||
|
||||
tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip)
|
||||
tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip)
|
||||
if not is_static(layout_tv.type) or not is_static(tiler_mn.type):
|
||||
raise ValueError(
|
||||
f"expects layout tv and tiler mn, but got {layout_tv.type} and {tiler_mn.type}"
|
||||
)
|
||||
tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get(
|
||||
atom.type, layout_tv.type, tiler_mn.type
|
||||
)
|
||||
val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip)
|
||||
# Instead of modifying atom which might have been provided by the user, create a brand new
|
||||
# trait instance and replace the Atom ir.Value with the tiled one
|
||||
trait = new_from_mlir_values(atom._trait, [val])
|
||||
return TiledCopy(atom.op, trait)
|
||||
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
|
||||
"""Create a tiled type given a TV partitioner and tiler.
|
||||
def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None):
|
||||
"""Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma.
|
||||
|
||||
:param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc.
|
||||
:param atom: Copy atom
|
||||
:type atom: CopyAtom
|
||||
:param layout_tv: Thread-value layout
|
||||
:type layout_tv: Layout
|
||||
:param tiler_mn: Tile size
|
||||
:type tiler_mn: Tiler
|
||||
:param tiled_mma: Tiled MMA
|
||||
:type tiled_mma: TiledMma
|
||||
:param loc: Source location for MLIR, defaults to None
|
||||
:type loc: Optional[Location], optional
|
||||
:param ip: Insertion point, defaults to None
|
||||
@ -5138,21 +5223,65 @@ def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
|
||||
:rtype: TiledCopy
|
||||
"""
|
||||
|
||||
# tiler_mn = pack_tuple(tiler_mn, make_tile)
|
||||
if type(tiler_mn) is tuple:
|
||||
tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip)
|
||||
|
||||
assert is_static(layout_tv.type) and is_static(
|
||||
tiler_mn.type
|
||||
), "layout tv and tiler mn must be static"
|
||||
tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get(
|
||||
atom.type, layout_tv.type, tiler_mn.type
|
||||
return _make_tiled_copy(
|
||||
atom,
|
||||
tiled_mma.tv_layout_A_tiled,
|
||||
(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None):
|
||||
"""Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma.
|
||||
|
||||
:param atom: Copy atom
|
||||
:type atom: CopyAtom
|
||||
:param tiled_mma: Tiled MMA
|
||||
:type tiled_mma: TiledMma
|
||||
:param loc: Source location for MLIR, defaults to None
|
||||
:type loc: Optional[Location], optional
|
||||
:param ip: Insertion point, defaults to None
|
||||
:type ip: Optional[InsertionPoint], optional
|
||||
|
||||
:return: A tiled copy for the partitioner
|
||||
:rtype: TiledCopy
|
||||
"""
|
||||
|
||||
return _make_tiled_copy(
|
||||
atom,
|
||||
tiled_mma.tv_layout_B_tiled,
|
||||
(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None):
|
||||
"""Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma.
|
||||
|
||||
:param atom: Copy atom
|
||||
:type atom: CopyAtom
|
||||
:param tiled_mma: Tiled MMA
|
||||
:type tiled_mma: TiledMma
|
||||
:param loc: Source location for MLIR, defaults to None
|
||||
:type loc: Optional[Location], optional
|
||||
:param ip: Insertion point, defaults to None
|
||||
:type ip: Optional[InsertionPoint], optional
|
||||
|
||||
:return: A tiled copy for the partitioner
|
||||
:rtype: TiledCopy
|
||||
"""
|
||||
|
||||
return _make_tiled_copy(
|
||||
atom,
|
||||
tiled_mma.tv_layout_C_tiled,
|
||||
(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip)
|
||||
# Instead of modifying atom which might have been provided by the user, create a brand new
|
||||
# trait instance and replace the Atom ir.Value with the tiled one
|
||||
trait = new_from_mlir_values(atom._trait, [val])
|
||||
return TiledCopy(atom.op, trait)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@ -5172,7 +5301,7 @@ def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None):
|
||||
:rtype: TiledCopy
|
||||
"""
|
||||
|
||||
return make_tiled_copy(
|
||||
return _make_tiled_copy(
|
||||
atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@ -5194,7 +5323,7 @@ def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None):
|
||||
:rtype: TiledCopy
|
||||
"""
|
||||
|
||||
return make_tiled_copy(
|
||||
return _make_tiled_copy(
|
||||
atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@ -5273,7 +5402,7 @@ def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None):
|
||||
|
||||
tiler_mn = _pack_tile(tiler, loc=loc, ip=ip)
|
||||
|
||||
return make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
|
||||
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
|
||||
|
||||
|
||||
####################################################################################################
|
||||
@ -5297,7 +5426,7 @@ def gemm(
|
||||
) -> None:
|
||||
"""The GEMM algorithm.
|
||||
|
||||
Computes ``D <- AB + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g.
|
||||
Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g.
|
||||
warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field.
|
||||
|
||||
All tensors must be partitioned according to the provided MMA Atom.
|
||||
@ -6416,7 +6545,8 @@ class struct:
|
||||
"""
|
||||
Decorator to abstract C structure in Python DSL.
|
||||
|
||||
Usage:
|
||||
**Usage:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
# Supports base_dsl scalar int/float elements, array and nested struct:
|
||||
@ -6424,12 +6554,15 @@ class struct:
|
||||
class complex:
|
||||
real : cutlass.Float32
|
||||
imag : cutlass.Float32
|
||||
|
||||
|
||||
@cute.struct
|
||||
class StorageA:
|
||||
mbarA : cute.struct.MemRange[cutlass.Int64, stage]
|
||||
compA : complex
|
||||
intA : cutlass.Int16
|
||||
|
||||
|
||||
# Supports aligment for its elements:
|
||||
@cute.struct
|
||||
class StorageB:
|
||||
@ -6442,6 +6575,7 @@ class struct:
|
||||
x: cute.struct.Align[cutlass.Int32, 16]
|
||||
compA: cute.struct.Align[complex, 16]
|
||||
|
||||
|
||||
# Statically get size and alignment:
|
||||
size = StorageB.__sizeof__()
|
||||
align = StorageB.__alignof__()
|
||||
|
||||
@ -94,12 +94,6 @@ def make_tiled_tma_atom(
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
# Wrap smem_layout in a composed layout to make it a TMA-friendly layout
|
||||
if isinstance(smem_layout, Layout):
|
||||
smem_layout = core.make_composed_layout(
|
||||
core.make_swizzle(0, 4, 3), 0, smem_layout
|
||||
)
|
||||
|
||||
if isinstance(op, CopyBulkTensorTileG2SOp):
|
||||
if num_multicast != 1:
|
||||
raise ValueError(
|
||||
|
||||
@ -37,7 +37,7 @@ from .cpasync.copy import (
|
||||
def make_tiled_tma_atom_A(
|
||||
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
||||
gmem_tensor: Tensor,
|
||||
smem_layout: Layout,
|
||||
smem_layout: Union[Layout, core.ComposedLayout],
|
||||
mma_tiler_mnk: Shape,
|
||||
tiled_mma: core.TiledMma,
|
||||
cluster_shape_vmnk: Shape,
|
||||
@ -76,7 +76,7 @@ def make_tiled_tma_atom_A(
|
||||
:param gmem_tensor: The GMEM tensor to be loaded by this copy atom
|
||||
:type gmem_tensor: Tensor
|
||||
:param smem_layout: Shared memory layout to load the tensor into (PDSL)
|
||||
:type smem_layout: Layout
|
||||
:type smem_layout: Union[Layout, core.ComposedLayout]
|
||||
:param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
|
||||
:type mma_tiler_mnk: Shape
|
||||
:param tiled_mma: The TiledMMA that will consume the load as operands
|
||||
@ -142,7 +142,7 @@ def make_tiled_tma_atom_A(
|
||||
def make_tiled_tma_atom_B(
|
||||
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
||||
gmem_tensor: Tensor,
|
||||
smem_layout: Layout,
|
||||
smem_layout: Union[Layout, core.ComposedLayout],
|
||||
mma_tiler_mnk: Shape,
|
||||
tiled_mma: core.TiledMma,
|
||||
cluster_shape_vmnk: Shape,
|
||||
@ -181,7 +181,7 @@ def make_tiled_tma_atom_B(
|
||||
:param gmem_tensor: The GMEM tensor to be loaded by this copy atom
|
||||
:type gmem_tensor: Tensor
|
||||
:param smem_layout: Shared memory layout to load the tensor into (PDSL)
|
||||
:type smem_layout: Layout
|
||||
:type smem_layout: Union[Layout, core.ComposedLayout]
|
||||
:param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
|
||||
:type mma_tiler_mnk: Shape
|
||||
:param tiled_mma: The TiledMMA that will consume the load as operands
|
||||
|
||||
@ -42,6 +42,9 @@ __all__ = [
|
||||
"MmaF16BF16Op",
|
||||
"MmaI8Op",
|
||||
"MmaFP8Op",
|
||||
"MmaMXF8Op",
|
||||
"MmaMXF4Op",
|
||||
"MmaMXF4NVF4Op",
|
||||
"SmemLayoutAtomKind",
|
||||
#
|
||||
# helpers.py
|
||||
@ -54,4 +57,6 @@ __all__ = [
|
||||
"get_tmem_copy_properties",
|
||||
"find_tmem_tensor_col_offset",
|
||||
"make_tmem_copy",
|
||||
"make_s2t_copy",
|
||||
"get_s2t_smem_desc_tensor",
|
||||
]
|
||||
|
||||
@ -23,6 +23,8 @@ from ..common import OpError
|
||||
from ...core import CopyOp, Trait
|
||||
from ...typing import Numeric
|
||||
|
||||
from .mma import CtaGroup
|
||||
|
||||
|
||||
class Repetition(enum.Enum):
|
||||
"""
|
||||
@ -469,3 +471,193 @@ class St32x32bOp(_StBase):
|
||||
|
||||
class St32x32bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _S2TCopyBase(CopyOp):
|
||||
cta_group: CtaGroup
|
||||
|
||||
admissible_archs = [
|
||||
"sm_100a",
|
||||
"sm_100f",
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
# Verify that the user provided enum values
|
||||
if not isinstance(self.cta_group, CtaGroup):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = (
|
||||
f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
|
||||
+ f"\n CTA group = {self.cta_group}"
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Cp128x256bOp(_S2TCopyBase):
|
||||
"""
|
||||
128x256b SMEM to TMEM Copy Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
|
||||
This Operation corresponds to the ``.128x256b`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Cp128x256bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
128,
|
||||
256,
|
||||
self.cta_group.value,
|
||||
_cute_nvgpu_ir.CopyS2TBroadcast.none,
|
||||
)
|
||||
return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Cp128x256bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Cp128x128bOp(_S2TCopyBase):
|
||||
"""
|
||||
128x128b SMEM to TMEM Copy Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
|
||||
This Operation corresponds to the ``.128x128b`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Cp128x128bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
128,
|
||||
128,
|
||||
self.cta_group.value,
|
||||
_cute_nvgpu_ir.CopyS2TBroadcast.none,
|
||||
)
|
||||
return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Cp128x128bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Cp4x256bOp(_S2TCopyBase):
|
||||
"""
|
||||
4x256b SMEM to TMEM Copy Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
|
||||
This Operation corresponds to the ``.4x256b`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Cp4x256bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
4,
|
||||
256,
|
||||
self.cta_group.value,
|
||||
_cute_nvgpu_ir.CopyS2TBroadcast.none,
|
||||
)
|
||||
return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Cp4x256bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Cp4x32x128bOp(_S2TCopyBase):
|
||||
"""
|
||||
32x128b SMEM to TMEM Copy Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
|
||||
This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Cp4x32x128bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
32,
|
||||
128,
|
||||
self.cta_group.value,
|
||||
_cute_nvgpu_ir.CopyS2TBroadcast.x4,
|
||||
)
|
||||
return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Cp4x32x128bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Cp2x64x128b0213Op(_S2TCopyBase):
|
||||
"""
|
||||
64x128b SMEM to TMEM Copy Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
|
||||
This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Cp2x64x128b0213Trait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
64,
|
||||
128,
|
||||
self.cta_group.value,
|
||||
_cute_nvgpu_ir.CopyS2TBroadcast.lw_0213,
|
||||
)
|
||||
return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Cp2x64x128b0213Trait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Cp2x64x128b0123Op(_S2TCopyBase):
|
||||
"""
|
||||
64x128b SMEM to TMEM Copy Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
|
||||
This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Cp2x64x128b0123Trait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
64,
|
||||
128,
|
||||
self.cta_group.value,
|
||||
_cute_nvgpu_ir.CopyS2TBroadcast.lw_0123,
|
||||
)
|
||||
return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Cp2x64x128b0123Trait(Trait):
|
||||
pass
|
||||
|
||||
@ -299,3 +299,30 @@ def make_tmem_copy(
|
||||
)
|
||||
new_trait = type(atom._trait)(tiled_copy_val)
|
||||
return core.TiledCopy(atom.op, new_trait)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_s2t_copy(
|
||||
atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
|
||||
) -> core.TiledCopy:
|
||||
"""
|
||||
Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
|
||||
"""
|
||||
tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy(
|
||||
atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
|
||||
)
|
||||
new_trait = type(atom._trait)(tiled_copy_val)
|
||||
return core.TiledCopy(atom.op, new_trait)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def get_s2t_smem_desc_tensor(
|
||||
atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None
|
||||
) -> Tensor:
|
||||
"""
|
||||
Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor.
|
||||
"""
|
||||
smem_desc_tensor = _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view(
|
||||
atom._trait.value, smem_tensor.value, loc=loc, ip=ip
|
||||
)
|
||||
return smem_desc_tensor
|
||||
|
||||
@ -20,9 +20,12 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
|
||||
from ... import core
|
||||
from ...core import Trait, _pack_shape, rank, depth, _Tensor
|
||||
from ...typing import (
|
||||
Shape,
|
||||
Float4E2M1FN,
|
||||
Float8E8M0FNU,
|
||||
Float8E5M2,
|
||||
Float8E4M3FN,
|
||||
Float16,
|
||||
@ -35,6 +38,7 @@ from ...typing import (
|
||||
Int32,
|
||||
Numeric,
|
||||
AddressSpace,
|
||||
Pointer,
|
||||
)
|
||||
|
||||
|
||||
@ -104,7 +108,6 @@ class CtaGroup(enum.Enum):
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
|
||||
class Field(enum.Enum):
|
||||
"""
|
||||
An enumeration for the fields of the MMA Atom that can be modified at runtime.
|
||||
@ -113,6 +116,8 @@ class Field(enum.Enum):
|
||||
NEGATE_A = "neg_a"
|
||||
NEGATE_B = "neg_b"
|
||||
ACCUMULATE = "accum_c"
|
||||
SFA = "sf_a"
|
||||
SFB = "sf_b"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
@ -124,9 +129,9 @@ class Field(enum.Enum):
|
||||
return self.value
|
||||
|
||||
|
||||
# Base class for all tcgen05 MMA Ops used to factor out some internal code
|
||||
# Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code
|
||||
@dataclass(frozen=True)
|
||||
class MmaOp(MmaOp):
|
||||
class MmaOp(core.MmaOp):
|
||||
a_dtype: Type[Numeric]
|
||||
b_dtype: Type[Numeric]
|
||||
acc_dtype: Type[Numeric]
|
||||
@ -256,6 +261,155 @@ class MmaTrait(Trait):
|
||||
)
|
||||
|
||||
|
||||
# Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code
|
||||
@dataclass(frozen=True)
|
||||
class BlockScaledMmaOp(core.MmaOp):
|
||||
a_dtype: Type[Numeric]
|
||||
b_dtype: Type[Numeric]
|
||||
acc_dtype: Float32
|
||||
sf_dtype: Type[Numeric]
|
||||
sf_vec_size: int
|
||||
shape_mnk: Shape
|
||||
cta_group: CtaGroup
|
||||
a_src: OperandSource
|
||||
a_major_mode: OperandMajorMode
|
||||
b_major_mode: OperandMajorMode
|
||||
|
||||
admissible_archs = [
|
||||
"sm_100a",
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
# Verify that the user provided enum values
|
||||
if not isinstance(self.cta_group, CtaGroup):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
|
||||
)
|
||||
if not isinstance(self.a_src, OperandSource):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
|
||||
)
|
||||
if not isinstance(self.a_major_mode, OperandMajorMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
|
||||
)
|
||||
if not isinstance(self.b_major_mode, OperandMajorMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
|
||||
)
|
||||
# Verify the instruction shape
|
||||
if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
|
||||
f"but got {self.shape_mnk}",
|
||||
)
|
||||
m, n = self.shape_mnk[0], self.shape_mnk[1]
|
||||
if self.cta_group == CtaGroup.ONE:
|
||||
if m != 128:
|
||||
raise OpError(self, f"expects the M-mode to be 128, but got {m}")
|
||||
|
||||
if (n < 8) or (n > 256) or (n % 8 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
|
||||
)
|
||||
else:
|
||||
if m not in [128, 256]:
|
||||
raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
|
||||
if (n < 16) or (n > 256) or (n % 16 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
|
||||
)
|
||||
if self.sf_vec_size not in [16, 32]:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
self.__class__.descriptive_name # type: ignore
|
||||
+ f"\n A data type = {self.a_dtype}"
|
||||
+ f"\n B data type = {self.b_dtype}"
|
||||
+ f"\n Accumulator data type = {self.acc_dtype}"
|
||||
+ f"\n Scale factor data type = {self.sf_dtype}"
|
||||
+ f"\n Scale factor vector size = {self.sf_vec_size}"
|
||||
+ f"\n CTA group = {self.cta_group}"
|
||||
+ f"\n A source location = {self.a_src}"
|
||||
+ f"\n A major mode = {self.a_major_mode}"
|
||||
+ f"\n B major mode = {self.b_major_mode}"
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
)
|
||||
|
||||
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
|
||||
if input.memspace == AddressSpace.smem and isinstance(
|
||||
input.layout.type, _cute_ir.ComposedLayoutType
|
||||
):
|
||||
raise OpError(
|
||||
self,
|
||||
f"Expected affine layout for {self._make_trait()}'s operand A, "
|
||||
f"but got composed layout instead: {input.layout}"
|
||||
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
||||
)
|
||||
return True
|
||||
|
||||
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
|
||||
if input.memspace == AddressSpace.smem and isinstance(
|
||||
input.layout.type, _cute_ir.ComposedLayoutType
|
||||
):
|
||||
raise OpError(
|
||||
self,
|
||||
f"Expected affine layout for {self._make_trait()}'s operand B, "
|
||||
f"but got composed layout instead: {input.layout}"
|
||||
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class BlockScaledMmaTraits(Trait):
|
||||
admissible_fields = [
|
||||
Field.ACCUMULATE,
|
||||
Field.NEGATE_A,
|
||||
Field.NEGATE_B,
|
||||
Field.SFA,
|
||||
Field.SFB,
|
||||
]
|
||||
|
||||
def set(self, field, value, *, loc=None, ip=None) -> None:
|
||||
if field not in self.admissible_fields:
|
||||
raise ValueError(
|
||||
f"expects field to be one of {self.admissible_fields}, but got {field}"
|
||||
)
|
||||
if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]:
|
||||
value = Boolean(value).ir_value(loc=loc, ip=ip)
|
||||
elif field in [Field.SFA, Field.SFB]:
|
||||
if not isinstance(value, Pointer):
|
||||
raise ValueError(
|
||||
f"expects value to be a pointer for {field}, but got {type(value).__name__}"
|
||||
)
|
||||
value = value.value
|
||||
|
||||
field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>"
|
||||
attr = ir.Attribute.parse(field_name)
|
||||
self.value = _cute_nvgpu_ir.atom_set_value(
|
||||
self.value, attr, value, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# TF32 MMA
|
||||
#
|
||||
@ -602,6 +756,262 @@ class MmaFP8Trait(MmaTrait):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# MXF8F6F4 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaMXF8Op(BlockScaledMmaOp):
|
||||
"""
|
||||
MXF8 tcgen05 BlockScaled MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
|
||||
This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier.
|
||||
"""
|
||||
|
||||
descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[Numeric],
|
||||
instruction_shape: Shape,
|
||||
cta_group: CtaGroup,
|
||||
a_src: OperandSource,
|
||||
a_major_mode: OperandMajorMode,
|
||||
b_major_mode: OperandMajorMode,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ab_dtype,
|
||||
ab_dtype,
|
||||
Float32,
|
||||
Float8E8M0FNU,
|
||||
32,
|
||||
instruction_shape,
|
||||
cta_group,
|
||||
a_src,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Input data type verification
|
||||
if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
|
||||
)
|
||||
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
|
||||
# Instruction shape verification
|
||||
instruction_k = 32
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.cta_group.value,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.sf_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
self.sf_vec_size,
|
||||
)
|
||||
return MmaMXF8Trait(
|
||||
_cute_nvgpu_ir.make_sm100_mma_bs(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
|
||||
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
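# Illustrative construction sketch (not from the source): parameter choices are
# hypothetical but follow the checks in _verify() above.
mma_op = MmaMXF8Op(
    Float8E4M3FN,        # ab_dtype (Float8E5M2 is also allowed)
    (128, 128, 32),      # instruction_shape; K must be 32
    CtaGroup.ONE,
    OperandSource.SMEM,
    OperandMajorMode.K,  # a_major_mode
    OperandMajorMode.K,  # b_major_mode
)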
|
||||
|
||||
|
||||
class MmaMXF8Trait(BlockScaledMmaTraits):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# MXF4 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaMXF4Op(BlockScaledMmaOp):
|
||||
"""
|
||||
MXF4 tcgen05 BlockScaled MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
|
||||
This Operation corresponds to the ``.kind::mxf4`` qualifier.
|
||||
"""
|
||||
|
||||
descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instruction_shape: Shape,
|
||||
cta_group: CtaGroup,
|
||||
a_src: OperandSource,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
Float4E2M1FN,
|
||||
Float4E2M1FN,
|
||||
Float32,
|
||||
Float8E8M0FNU,
|
||||
32,
|
||||
instruction_shape,
|
||||
cta_group,
|
||||
a_src,
|
||||
OperandMajorMode.K,
|
||||
OperandMajorMode.K,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Instruction shape verification
|
||||
instruction_k = 64
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.cta_group.value,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.sf_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
self.sf_vec_size,
|
||||
)
|
||||
return MmaMXF4Trait(
|
||||
_cute_nvgpu_ir.make_sm100_mma_bs(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
|
||||
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaMXF4Trait(BlockScaledMmaTraits):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# MXF4NVF4 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaMXF4NVF4Op(BlockScaledMmaOp):
|
||||
"""
|
||||
MXF4NVF4 tcgen05 BlockScaled MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
|
||||
This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier.
|
||||
"""
|
||||
|
||||
descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sf_dtype: Type[Numeric],
|
||||
instruction_shape: Shape,
|
||||
cta_group: CtaGroup,
|
||||
a_src: OperandSource,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
Float4E2M1FN,
|
||||
Float4E2M1FN,
|
||||
Float32,
|
||||
sf_dtype,
|
||||
16,
|
||||
instruction_shape,
|
||||
cta_group,
|
||||
a_src,
|
||||
OperandMajorMode.K,
|
||||
OperandMajorMode.K,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Scale Factor data type verification
|
||||
if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU",
|
||||
)
|
||||
# Instruction shape verification
|
||||
instruction_k = 64
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4NVF4Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.cta_group.value,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.sf_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
self.sf_vec_size,
|
||||
)
|
||||
return MmaMXF4NVF4Trait(
|
||||
_cute_nvgpu_ir.make_sm100_mma_bs(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
|
||||
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaMXF4NVF4Trait(BlockScaledMmaTraits):
|
||||
pass
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# SMEM layout atoms
|
||||
|
||||
@ -28,11 +28,12 @@ import cutlass.base_dsl.jit_executor
|
||||
import cutlass.cute as cute
|
||||
from cutlass._mlir.dialects import builtin, cf, nvvm, vector
|
||||
from cutlass.cute import core, nvgpu
|
||||
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t
|
||||
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op
|
||||
|
||||
|
||||
def assert_(cond, msg=None):
|
||||
cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "")
|
||||
@dsl_user_op
|
||||
def assert_(cond, msg=None, *, loc=None, ip=None):
|
||||
cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip)
|
||||
|
||||
|
||||
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):
|
||||
@ -214,7 +215,14 @@ def convert(src: core.Tensor, dst: core.Tensor):
|
||||
dst.shape
|
||||
), "Shape of src and dst tensors should be the same rank."
|
||||
# find leading mode
|
||||
leading_mode = np.argmin([np.min(s) for s in src.stride])
|
||||
leading_mode = [
|
||||
idx
|
||||
for idx, (shape, stride) in enumerate(zip(src.shape, src.stride))
|
||||
if shape > 1 and stride == 1
|
||||
]
|
||||
if len(leading_mode) != 1:
|
||||
raise ValueError(f"Leading mode should be unique, but got {leading_mode}")
|
||||
leading_mode = leading_mode[0]
|
||||
|
||||
elem_per_copy = 2
|
||||
|
||||
@ -345,7 +353,7 @@ def benchmark(
|
||||
callable: Callable,
|
||||
*,
|
||||
warmup_iterations: int = 10,
|
||||
profiling_iterations: int = 100,
|
||||
iterations: int = 100,
|
||||
stream: Optional[cuda_driver.CUstream] = None,
|
||||
kernel_arguments: Optional[JitArguments] = None,
|
||||
workspace_generator: Optional[Callable[[], JitArguments]] = None,
|
||||
@ -365,7 +373,7 @@ def benchmark(
|
||||
pass
|
||||
|
||||
time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream),
|
||||
warmup_iterations=10, profiling_iterations=100
|
||||
warmup_iterations=10, iterations=100
|
||||
stream=stream)
|
||||
|
||||
To prevent skewing results by repeatedly accessing the L2 cache, use the workspace_count and workspace_generator
|
||||
@ -388,7 +396,7 @@ def benchmark(
|
||||
workspace_generator=workspace_generator,
|
||||
workspace_count=10,
|
||||
warmup_iterations=10000,
|
||||
profiling_iterations=1000)
|
||||
iterations=1000)
|
||||
|
||||
To benchmark you may always configure the function being profiled (callable), the warmup iterations, and
|
||||
the number of profiling iterations.
|
||||
@ -402,8 +410,8 @@ def benchmark(
|
||||
:type callable: Callable
|
||||
:param warmup_iterations: Number of warmup iterations, defaults to 10
|
||||
:type warmup_iterations: int, optional
|
||||
:param profiling_iterations: Number of benchmark iterations, defaults to 100
|
||||
:type profiling_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations, defaults to 100
|
||||
:type iterations: int, optional
|
||||
:param stream: Stream kernel is launched in, defaults to CUDA stream default
|
||||
:type stream: CUstream, None
|
||||
:param kernel_arguments: Kernel arguments to launch callable with, defaults to None
|
||||
@ -502,7 +510,7 @@ def benchmark(
|
||||
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
|
||||
)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
_loop_and_call_kernel(profiling_iterations, workspace_index)
|
||||
_loop_and_call_kernel(iterations, workspace_index)
|
||||
err, gprofile = cuda_runtime.cudaStreamEndCapture(stream)
|
||||
_cuda_success(err, "Error on stream capture")
|
||||
|
||||
@ -557,7 +565,7 @@ def benchmark(
|
||||
# Record start event
|
||||
err = cuda_driver.cuEventRecord(start_event, stream)
|
||||
_cuda_success(err, "Error on recording event")
|
||||
_loop_and_call_kernel(profiling_iterations, workspace_index)
|
||||
_loop_and_call_kernel(iterations, workspace_index)
|
||||
# Record end event
|
||||
err = cuda_driver.cuEventRecord(end_event, stream)
|
||||
_cuda_success(err, "Error on recording event")
|
||||
@ -573,6 +581,30 @@ def benchmark(
|
||||
err = cuda_driver.cuEventDestroy(end_event)
|
||||
_cuda_success(err, "Error on destroying event")
|
||||
|
||||
return elapsed_time / profiling_iterations * 1e3
|
||||
return elapsed_time / iterations * 1e3
|
||||
|
||||
|
||||
def get_workspace_count(
|
||||
one_workspace_bytes: int, warmup_iterations: int, iterations: int
|
||||
) -> int:
|
||||
"""Calculate the number of workspaces needed to fill L2 cache.
|
||||
|
||||
:param one_workspace_bytes: Size of one workspace in bytes
|
||||
:type one_workspace_bytes: int
|
||||
:param warmup_iterations: Number of warmup iterations
|
||||
:type warmup_iterations: int
|
||||
:param iterations: Number of iterations
|
||||
:type iterations: int
|
||||
:return: Number of workspaces needed
|
||||
:rtype: int
|
||||
"""
|
||||
num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes()
|
||||
return max(
|
||||
1,
|
||||
min(
|
||||
warmup_iterations + iterations, # Don't create more workspaces than needed
|
||||
(num_l2_cache_bytes + one_workspace_bytes - 1)
|
||||
// one_workspace_bytes, # Ceiling division
|
||||
),
|
||||
)
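# Rough usage sketch (not from the source; a, b, c, user_function, stream, and
# workspace_generator are hypothetical): size one workspace, then let benchmark()
# rotate across enough copies to avoid re-hitting data resident in L2.
one_workspace_bytes = sum(t.numel() * t.element_size() for t in (a, b, c))
workspace_count = get_workspace_count(
    one_workspace_bytes, warmup_iterations=10, iterations=100
)
time_us = benchmark(
    user_function,
    workspace_generator=workspace_generator,
    workspace_count=workspace_count,
    warmup_iterations=10,
    iterations=100,
    stream=stream,
)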
|
||||
|
||||
|
||||
@ -91,33 +91,6 @@ class PipelineAsync:
|
||||
- D: Data ready (producer has written data to buffer)
|
||||
- R: Consumer reading (consumer is consuming data from buffer)
|
||||
|
||||
**Example:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create pipeline with 5 stages
|
||||
pipeline = PipelineAsync.create(
|
||||
num_stages=5, # number of pipeline stages
|
||||
producer_group=producer_warp,
|
||||
consumer_group=consumer_warp
|
||||
barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory
|
||||
)
|
||||
|
||||
# Producer side
|
||||
producer = pipeline.make_pipeline_producer(producer_warp)
|
||||
for i in range(num_iterations):
|
||||
producer.acquire() # Wait for buffer to be empty
|
||||
# Write data to pipeline buffer
|
||||
producer.commit() # Signal buffer is full
|
||||
producer.advance() # Move index to next stage
|
||||
|
||||
# Consumer side
|
||||
consumer = pipeline.make_pipeline_consumer(consumer_warp)
|
||||
for i in range(num_iterations):
|
||||
consumer.wait() # Wait for buffer to be full
|
||||
# Read data from pipeline buffer
|
||||
consumer.release() # Signal buffer is empty
|
||||
consumer.advance() # Move index to next stage
|
||||
"""
|
||||
|
||||
sync_object_full: SyncObject
|
||||
@ -259,16 +232,6 @@ class PipelineAsync:
|
||||
state.advance()
|
||||
self.producer_acquire(state)
|
||||
|
||||
# Util methods to manage producer and consumer
|
||||
def make_pipeline_producer(self, group: CooperativeGroup):
|
||||
state = make_pipeline_state(PipelineUserType.Producer, self.num_stages)
|
||||
return PipelineProducer(self, state, group)
|
||||
|
||||
def make_pipeline_consumer(self, group: CooperativeGroup):
|
||||
state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages)
|
||||
return PipelineConsumer(self, state, group)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTmaAsync(PipelineAsync):
|
||||
"""
|
||||
@ -593,211 +556,3 @@ class PipelineTmaStore(PipelineAsync):
|
||||
self.sync_object_full.tail()
|
||||
|
||||
|
||||
#################################################################
|
||||
# Utilities to help users of the pipeline simplify the workflow
|
||||
#################################################################
|
||||
|
||||
|
||||
class PipelineProducer:
|
||||
"""A class representing a producer in an asynchronous pipeline.
|
||||
|
||||
The Producer class manages the producer side of an asynchronous pipeline, handling
|
||||
synchronization and state management for producing data. It provides methods for
|
||||
acquiring, committing, and advancing through pipeline stages.
|
||||
|
||||
:ivar _pipeline: The asynchronous pipeline this producer belongs to
|
||||
:type _pipeline: PipelineAsync
|
||||
:ivar _state: The current state of the producer in the pipeline
|
||||
:type _state: PipelineState
|
||||
:ivar _group: The cooperative group this producer operates in
|
||||
:type _group: CooperativeGroup
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pipeline = PipelineAsync.create(...)
|
||||
producer = pipeline.create_producer(producer_group, stages)
|
||||
for i in range(iterations):
|
||||
producer.acquire() # Wait for buffer to be empty
|
||||
# Produce data
|
||||
producer.commit() # Signal data is ready
|
||||
producer.advance() # Move to next stage
|
||||
"""
|
||||
|
||||
_pipeline: PipelineAsync
|
||||
_state: PipelineState
|
||||
_group: CooperativeGroup
|
||||
|
||||
def __init__(self, pipeline, state, group: CooperativeGroup):
|
||||
"""Initialize a new Producer instance.
|
||||
|
||||
:param pipeline: The pipeline this producer belongs to
|
||||
:type pipeline: PipelineAsync
|
||||
:param state: Initial pipeline state
|
||||
:type state: PipelineState
|
||||
:param group: The cooperative group for synchronization
|
||||
:type group: CooperativeGroup
|
||||
"""
|
||||
self._pipeline = pipeline
|
||||
self._state = state
|
||||
self._group = group
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
"""Get the index of the current pipeline stage."""
|
||||
return self._state.index
|
||||
|
||||
def get_barrier(self):
|
||||
"""Get the barrier pointer for the current pipeline stage.
|
||||
|
||||
:return: Pointer to the barrier for the current stage
|
||||
:rtype: cute.Pointer
|
||||
"""
|
||||
return self._pipeline.producer_get_barrier(self._state)
|
||||
|
||||
def acquire(self):
|
||||
"""Wait for the current buffer to be empty before producing data.
|
||||
This is a blocking operation.
|
||||
"""
|
||||
self._pipeline.producer_acquire(self._state)
|
||||
|
||||
def try_acquire(self):
|
||||
"""Try to acquire the current buffer without blocking.
|
||||
|
||||
:return: True if acquisition was successful, False otherwise
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._pipeline.producer_try_acquire(self._state)
|
||||
|
||||
def commit(self):
|
||||
"""Signal that data production is complete for the current stage.
|
||||
This allows consumers to start processing the data.
|
||||
"""
|
||||
self._pipeline.producer_commit(self._state)
|
||||
|
||||
def tail(self):
|
||||
"""Ensure all used buffers are properly synchronized before producer exit.
|
||||
This should be called before the producer finishes to avoid dangling signals.
|
||||
"""
|
||||
self._pipeline.producer_tail(self._state)
|
||||
|
||||
def advance(self):
|
||||
"""Move to the next pipeline stage."""
|
||||
self._state.advance()
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
"""Extract MLIR values from the current state.
|
||||
|
||||
:return: List of MLIR values representing the current state
|
||||
:rtype: list
|
||||
"""
|
||||
# TODO: need to handle pipeline as well
|
||||
return self._state.__extract_mlir_values__()
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
"""Create a new Producer instance from MLIR values.
|
||||
|
||||
:param values: MLIR values to initialize the state
|
||||
:type values: Any
|
||||
:return: New Producer instance with state initialized from values
|
||||
:rtype: Producer
|
||||
"""
|
||||
return PipelineProducer(
|
||||
self._pipeline, self._state.__new_from_mlir_values__(values), self._group
|
||||
)
|
||||
|
||||
|
||||
class PipelineConsumer:
|
||||
"""A class representing a consumer in an asynchronous pipeline.
|
||||
|
||||
The Consumer class manages the consumer side of an asynchronous pipeline, handling
|
||||
synchronization and state management for consuming data. It provides methods for
|
||||
waiting, releasing, and advancing through pipeline stages.
|
||||
|
||||
:ivar _pipeline: The asynchronous pipeline this consumer belongs to
|
||||
:type _pipeline: PipelineAsync
|
||||
:ivar _state: The current state of the consumer in the pipeline
|
||||
:type _state: PipelineState
|
||||
:ivar _group: The cooperative group this consumer operates in
|
||||
:type _group: CooperativeGroup
|
||||
|
||||
**Examples:**
|
||||
.. code-block:: python
|
||||
|
||||
pipeline = PipelineAsync.create(...)
|
||||
consumer = pipeline.create_consumer(consumer_group, stages)
|
||||
for i in range(iterations):
|
||||
consumer.wait() # Wait for data to be ready
|
||||
# Consume data
|
||||
consumer.release() # Signal buffer is empty
|
||||
consumer.advance() # Move to next stage
|
||||
"""
|
||||
|
||||
_pipeline: PipelineAsync
|
||||
_state: PipelineState
|
||||
_group: CooperativeGroup
|
||||
|
||||
def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup):
|
||||
"""Initialize a new Consumer instance.
|
||||
|
||||
:param pipeline: The pipeline this consumer belongs to
|
||||
:type pipeline: PipelineAsync
|
||||
:param state: Initial pipeline state
|
||||
:type state: PipelineState
|
||||
:param group: The cooperative group for synchronization
|
||||
:type group: CooperativeGroup
|
||||
"""
|
||||
self._pipeline = pipeline
|
||||
self._group = group
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
"""Get the index of the current pipeline stage."""
|
||||
return self._state.index
|
||||
|
||||
def wait(self):
|
||||
"""Wait for data to be ready in the current buffer.
|
||||
This is a blocking operation.
|
||||
"""
|
||||
self._pipeline.consumer_wait(self._state)
|
||||
|
||||
def try_wait(self):
|
||||
"""Try to check if data is ready without blocking.
|
||||
|
||||
:return: True if data is ready, False otherwise
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._pipeline.consumer_try_wait(self._state)
|
||||
|
||||
def release(self):
|
||||
"""Signal that data consumption is complete for the current stage.
|
||||
This allows producers to start producing new data.
|
||||
"""
|
||||
self._pipeline.consumer_release(self._state)
|
||||
|
||||
def advance(self):
|
||||
"""Move to the next pipeline stage."""
|
||||
self._state.advance()
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
"""Extract MLIR values from the current state.
|
||||
|
||||
:return: List of MLIR values representing the current state
|
||||
:rtype: list
|
||||
"""
|
||||
return self._state.__extract_mlir_values__()
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
"""Create a new Consumer instance from MLIR values.
|
||||
|
||||
:param values: MLIR values to initialize the state
|
||||
:type values: Any
|
||||
:return: New Consumer instance with state initialized from values
|
||||
:rtype: Consumer
|
||||
"""
|
||||
# TODO: need to call pipeline.__new_from_mlir_values__ recursively
|
||||
return PipelineConsumer(
|
||||
self._pipeline, self._state.__new_from_mlir_values__(values), self._group
|
||||
)
|
||||
|
||||
@ -29,7 +29,7 @@ from cutlass.cute.typing import (
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.cute as cute
|
||||
import torch
|
||||
from cuda import cuda
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
def dtype(ty: Type[Numeric]):
|
||||
|
||||
@ -28,12 +28,22 @@ from .blackwell_helpers import (
|
||||
make_smem_layout_b,
|
||||
make_smem_layout_epi,
|
||||
make_trivial_tiled_mma,
|
||||
make_blockscaled_trivial_tiled_mma,
|
||||
)
|
||||
|
||||
from .hopper_helpers import (
|
||||
sm90_get_smem_store_op,
|
||||
)
|
||||
|
||||
from .blockscaled_layout import (
|
||||
BlockScaledBasicChunk,
|
||||
tile_atom_to_shape_SF,
|
||||
make_smem_layout_sfa,
|
||||
make_smem_layout_sfb,
|
||||
make_tmem_layout_sfa,
|
||||
make_tmem_layout_sfb,
|
||||
)
|
||||
|
||||
from .grouped_gemm_tile_scheduler_helper import (
|
||||
GroupSearchResult,
|
||||
GroupedGemmGroupSearchState,
|
||||
@ -50,7 +60,12 @@ from .smem_allocator import SmemAllocator
|
||||
|
||||
from .layout import LayoutEnum
|
||||
|
||||
from .smem_capacity import (
|
||||
get_smem_capacity_in_bytes,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_smem_capacity_in_bytes",
|
||||
"SmemAllocator",
|
||||
"LayoutEnum",
|
||||
"WorkTileInfo",
|
||||
|
||||
@ -10,14 +10,22 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
from enum import Enum
|
||||
from typing_extensions import deprecated
|
||||
import warnings
|
||||
|
||||
|
||||
@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead")
|
||||
class SmemCapacity(Enum):
|
||||
SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024
|
||||
SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
|
||||
SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# Dictionary to map compute capability to SMEM capacity
|
||||
SMEM_CAPACITY = {
|
||||
"sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value,
|
||||
|
||||
@ -12,6 +12,8 @@
|
||||
from enum import Enum
|
||||
from math import log2, ceil
|
||||
from typing import List, Type, Union, Tuple
|
||||
from typing_extensions import deprecated
|
||||
import warnings
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
Float16,
|
||||
@ -22,6 +24,7 @@ from cutlass.cutlass_dsl import (
|
||||
Int8,
|
||||
Float8E4M3FN,
|
||||
Float8E5M2,
|
||||
Float4E2M1FN,
|
||||
Numeric,
|
||||
NumericMeta,
|
||||
dsl_user_op,
|
||||
@ -34,6 +37,9 @@ from cutlass.cute.nvgpu.tcgen05 import (
|
||||
MmaTF32Op,
|
||||
MmaI8Op,
|
||||
MmaFP8Op,
|
||||
MmaMXF8Op,
|
||||
MmaMXF4Op,
|
||||
MmaMXF4NVF4Op,
|
||||
OperandSource,
|
||||
OperandMajorMode,
|
||||
CtaGroup,
|
||||
@ -58,6 +64,24 @@ from cutlass.cute.nvgpu.cpasync import (
|
||||
from cutlass.utils.layout import LayoutEnum
|
||||
|
||||
|
||||
@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead")
|
||||
class SmemCapacity(Enum):
|
||||
SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
|
||||
SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# Dictionary to map compute capability to SMEM capacity
|
||||
SMEM_CAPACITY = {
|
||||
"sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value,
|
||||
"sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
|
||||
}
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def compute_epilogue_tile_shape(
|
||||
cta_tile_shape: cute.Shape,
|
||||
@ -822,18 +846,6 @@ def make_smem_layout_epi(
|
||||
return epi_smem_layout_staged
|
||||
|
||||
|
||||
class SmemCapacity(Enum):
|
||||
SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
|
||||
SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
|
||||
|
||||
|
||||
# Dictionary to map compute capability to SMEM capacity
|
||||
SMEM_CAPACITY = {
|
||||
"sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value,
|
||||
"sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
|
||||
}
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_trivial_tiled_mma(
|
||||
ab_dtype: Type[Numeric],
|
||||
@ -917,6 +929,76 @@ def make_trivial_tiled_mma(
|
||||
return cute.make_tiled_mma(cute.make_mma_atom(mma_op))
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_blockscaled_trivial_tiled_mma(
|
||||
ab_dtype: Type[Numeric],
|
||||
a_leading_mode: OperandMajorMode,
|
||||
b_leading_mode: OperandMajorMode,
|
||||
sf_dtype: Type[Numeric],
|
||||
sf_vec_size: int,
|
||||
cta_group: CtaGroup,
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
a_source: OperandSource = OperandSource.SMEM,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.TiledMma:
|
||||
"""Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.
|
||||
By default, the MMA atom is created with SMEM operand source for A.
|
||||
|
||||
:param ab_dtype: Data type of operands A and B.
|
||||
:type ab_dtype: type[Numeric]
|
||||
:param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N).
|
||||
:type a_leading_mode: tcgen05.OperandMajorMode
|
||||
:param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N).
|
||||
:type b_leading_mode: tcgen05.OperandMajorMode
|
||||
:param sf_dtype: Data type of the Scale Factor.
|
||||
:type sf_dtype: type[Numeric]
|
||||
:param sf_vec_size: The vector size of the Scale Factor.
|
||||
:type sf_vec_size: int
|
||||
:param cta_group: The CTA group to use.
|
||||
:type cta_group: tcgen05.CtaGroup
|
||||
:param mma_tiler_mn: The shape (M, N, K) of the MMA tiler.
|
||||
:type mma_tiler_mn: Tuple[int, int]
|
||||
:param a_source: The source of operand A (SMEM by default or TMEM).
|
||||
:type a_source: OperandSource
|
||||
|
||||
:return: A tiled MMA atom.
|
||||
:rtype: cute.TiledMma
|
||||
|
||||
:raises TypeError: If the data type is not supported.
|
||||
"""
|
||||
if ab_dtype in {Float8E4M3FN, Float8E5M2}:
|
||||
mma_op = MmaMXF8Op(
|
||||
ab_dtype,
|
||||
(*mma_tiler_mn, 32),
|
||||
cta_group,
|
||||
a_source,
|
||||
a_leading_mode,
|
||||
b_leading_mode,
|
||||
)
|
||||
elif ab_dtype == Float4E2M1FN:
|
||||
if sf_vec_size == 32:
|
||||
mma_op = MmaMXF4Op(
|
||||
(*mma_tiler_mn, 64),
|
||||
cta_group,
|
||||
a_source,
|
||||
)
|
||||
elif sf_vec_size == 16:
|
||||
mma_op = MmaMXF4NVF4Op(
|
||||
sf_dtype,
|
||||
(*mma_tiler_mn, 64),
|
||||
cta_group,
|
||||
a_source,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported sf_vec_size, got {sf_vec_size}")
|
||||
else:
|
||||
raise TypeError(f"unsupported ab_dtype, got {ab_dtype}")
|
||||
|
||||
return cute.make_tiled_mma(cute.make_mma_atom(mma_op))
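# Example (a sketch, not from the source): request an MXF8 block-scaled tiled MMA
# for a 128x128 MMA tile; assumes Float8E8M0FNU is available from cutlass.cutlass_dsl
# like the other numeric types imported above.
tiled_mma = make_blockscaled_trivial_tiled_mma(
    Float8E4M3FN,        # ab_dtype -> takes the MmaMXF8Op branch
    OperandMajorMode.K,  # a_leading_mode
    OperandMajorMode.K,  # b_leading_mode
    Float8E8M0FNU,       # sf_dtype
    32,                  # sf_vec_size
    CtaGroup.ONE,
    (128, 128),          # mma_tiler_mn
)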
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cluster_shape_to_tma_atom_A(
|
||||
cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None
|
||||
|
||||
287
python/CuTeDSL/cutlass/utils/blockscaled_layout.py
Normal file
@ -0,0 +1,287 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Union
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BlockScaledBasicChunk:
|
||||
"""
|
||||
The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops.
|
||||
|
||||
This class represents the fixed layout pattern for scale factors used in
|
||||
tcgen05 BlockScaled MMA Ops. The layout is determined by the
|
||||
instruction specification and cannot be modified.
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x>`__.
|
||||
"""
|
||||
|
||||
sf_vec_size: int
|
||||
major_mode: OperandMajorMode = OperandMajorMode.K
|
||||
_layout: cute.Layout = field(init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.major_mode == OperandMajorMode.K:
|
||||
# K-major layout: (AtomMN, AtomK)
|
||||
atom_shape = ((32, 4), (self.sf_vec_size, 4))
|
||||
atom_stride = ((16, 4), (0, 1))
|
||||
else:
|
||||
# MN-major layout: (AtomK, AtomMN)
|
||||
atom_shape = ((self.sf_vec_size, 4), (32, 4))
|
||||
atom_stride = ((0, 1), (16, 4))
|
||||
|
||||
object.__setattr__(
|
||||
self, "_layout", cute.make_layout(atom_shape, stride=atom_stride)
|
||||
)
|
||||
|
||||
@property
|
||||
def layout(self) -> cute.Layout:
|
||||
"""
|
||||
Get the layout for this block scaled chunk.
|
||||
|
||||
:return: The layout representing the scale factor atom
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
return self._layout
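# Quick illustration (derived from __post_init__ above): the default K-major
# chunk for a 32-wide scale-factor vector.
chunk = BlockScaledBasicChunk(sf_vec_size=32)
# chunk.layout has shape ((32, 4), (32, 4)) with stride ((16, 4), (0, 1));
# the stride-0 mode broadcasts one scale factor across the 32-element vector.
sf_atom_layout = chunk.layout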
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def tile_atom_to_shape_SF(
|
||||
Shape: cute.Shape,
|
||||
sf_vec_size: int,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.Layout:
|
||||
"""
|
||||
A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout.
|
||||
|
||||
:param Shape: The shape of the A/B tensor
|
||||
:param sf_vec_size: Scale factor vector size
|
||||
|
||||
:return: The layout of the SFA/SFB tensor
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
# ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL)
|
||||
sf_layout = cute.tile_to_shape(
|
||||
BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3)
|
||||
)
|
||||
return sf_layout
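# Minimal sketch (the (M, K, L) extents below are made up for illustration):
# fill a dynamic A-tensor shape into the scale-factor atom.
sfa_layout = tile_atom_to_shape_SF((1024, 512, 1), sf_vec_size=32)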
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_smem_layout_sfa(
|
||||
tiled_mma: cute.TiledMma,
|
||||
mma_tiler_mnk: cute.Tile,
|
||||
sf_vec_size: int,
|
||||
num_stages: int,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.Layout:
|
||||
"""
|
||||
Make smem layout for SFA based on:
|
||||
1. BlockScaledBasicChunk
|
||||
2. MMA tiler shape
|
||||
3. Scale factor vector size
|
||||
4. Number of stages
|
||||
|
||||
:param tiled_mma: The tiled MMA
|
||||
:type tiled_mma: cute.TiledMma
|
||||
:param mma_tiler_mnk: The mma tiler shape
|
||||
:type mma_tiler_mnk: cute.Tile
|
||||
:param sf_vec_size: The scale factor vector size
|
||||
:type sf_vec_size: int
|
||||
:param num_stages: The number of stages
|
||||
:type num_stages: int
|
||||
|
||||
:return: Smem layout for SFA
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
# (CTA_Tile_Shape_M, MMA_Tile_Shape_K)
|
||||
sfa_tile_shape = (
|
||||
mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
mma_tiler_mnk[2],
|
||||
)
|
||||
|
||||
# ((Atom_M, Rest_M),(Atom_K, Rest_K))
|
||||
smem_layout = cute.tile_to_shape(
|
||||
BlockScaledBasicChunk(sf_vec_size).layout,
|
||||
sfa_tile_shape,
|
||||
(2, 1),
|
||||
)
|
||||
|
||||
mma_tile_inst_k = 4
|
||||
# (CTA_Tile_Shape_M, MMA_Inst_Shape_K)
|
||||
sfa_tile_shape = cute.shape_div(sfa_tile_shape, (1, mma_tile_inst_k))
|
||||
# ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K))
|
||||
smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape)
|
||||
|
||||
atom_m = 128
|
||||
tiler_inst = ((atom_m, sf_vec_size),)
|
||||
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
|
||||
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
|
||||
|
||||
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
|
||||
sfa_smem_layout_staged = cute.append(
|
||||
smem_layout,
|
||||
cute.make_layout(
|
||||
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
|
||||
),
|
||||
)
|
||||
|
||||
return sfa_smem_layout_staged
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_smem_layout_sfb(
|
||||
tiled_mma: cute.TiledMma,
|
||||
mma_tiler_mnk: cute.Tile,
|
||||
sf_vec_size: int,
|
||||
num_stages: int,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.Layout:
|
||||
"""
|
||||
Make smem layout for SFB based on:
|
||||
1. BlockScaledBasicChunk
|
||||
2. MMA tiler shape
|
||||
3. Scale factor vector size
|
||||
4. Number of stages
|
||||
|
||||
:param tiled_mma: The tiled MMA
|
||||
:type tiled_mma: cute.TiledMma
|
||||
:param mma_tiler_mnk: The mma tiler shape
|
||||
:type mma_tiler_mnk: cute.Tile
|
||||
:param sf_vec_size: The scale factor vector size
|
||||
:type sf_vec_size: int
|
||||
:param num_stages: The number of stages
|
||||
:type num_stages: int
|
||||
|
||||
:return: Smem layout for SFB
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
# (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K)
|
||||
sfb_tile_shape = (
|
||||
cute.round_up(mma_tiler_mnk[1], 128),
|
||||
mma_tiler_mnk[2],
|
||||
)
|
||||
|
||||
# ((Atom_N, Rest_N),(Atom_K, Rest_K))
|
||||
smem_layout = cute.tile_to_shape(
|
||||
BlockScaledBasicChunk(sf_vec_size).layout,
|
||||
sfb_tile_shape,
|
||||
(2, 1),
|
||||
)
|
||||
|
||||
mma_tile_inst_k = 4
|
||||
# (CTA_Tile_Shape_N, MMA_Inst_Shape_K)
|
||||
sfb_tile_shape = cute.shape_div(sfb_tile_shape, (1, mma_tile_inst_k))
|
||||
# ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K)
|
||||
smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape)
|
||||
|
||||
atom_n = 128
|
||||
tiler_inst = ((atom_n, sf_vec_size),)
|
||||
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
|
||||
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
|
||||
|
||||
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
|
||||
sfb_smem_layout_staged = cute.append(
|
||||
smem_layout,
|
||||
cute.make_layout(
|
||||
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
|
||||
),
|
||||
)
|
||||
|
||||
return sfb_smem_layout_staged
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tmem_layout_sfa(
|
||||
tiled_mma: cute.TiledMma,
|
||||
mma_tiler_mnk: cute.Tile,
|
||||
sf_vec_size: int,
|
||||
smem_layout: cute.Layout,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.Layout:
|
||||
"""Make tmem layout for SFA based on:
|
||||
1. SFA smem layout per stage
|
||||
2. Cta tile shape m
|
||||
3. tiled MMA atom thr size
|
||||
4. Scale factor vector size
|
||||
|
||||
:param tiled_mma: The tiled MMA
|
||||
:type tiled_mma: cute.TiledMma
|
||||
:param mma_tiler_mnk: The mma tiler shape
|
||||
:type mma_tiler_mnk: cute.Tile
|
||||
:param sf_vec_size: The scale factor vector size
|
||||
:type sf_vec_size: int
|
||||
:param smem_layout: The smem layout of SFA per stage
|
||||
:type smem_layout: cute.Layout
|
||||
|
||||
:return: TMEM layout for SFA
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size
|
||||
|
||||
sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa(
|
||||
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
|
||||
)
|
||||
return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tmem_layout_sfb(
|
||||
tiled_mma: cute.TiledMma,
|
||||
mma_tiler_mnk: cute.Tile,
|
||||
sf_vec_size: int,
|
||||
smem_layout: cute.Layout,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.Layout:
|
||||
"""Make tmem layout for SFB based on:
|
||||
1. SFB smem layout per stage
|
||||
2. Cta tile shape m
|
||||
3. tiled MMA atom thr size
|
||||
4. Scale factor vector size
|
||||
|
||||
:param tiled_mma: The tiled MMA
|
||||
:type tiled_mma: cute.TiledMma
|
||||
:param mma_tiler_mnk: The mma tiler shape
|
||||
:type mma_tiler_mnk: cute.Tile
|
||||
:param sf_vec_size: The scale factor vector size
|
||||
:type sf_vec_size: int
|
||||
:param smem_layout: The smem layout of SFB per stage
|
||||
:type smem_layout: cute.Layout
|
||||
|
||||
:return: TMEM layout for SFB
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size
|
||||
|
||||
sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb(
|
||||
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
|
||||
)
|
||||
return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip)
|
||||
@ -11,6 +11,8 @@
|
||||
|
||||
from typing import Type, Tuple
|
||||
from enum import Enum
|
||||
from typing_extensions import deprecated
|
||||
import warnings
|
||||
|
||||
from cutlass.utils.layout import LayoutEnum
|
||||
from cutlass.cutlass_dsl import (
|
||||
@ -34,6 +36,23 @@ from cutlass.cute.nvgpu.warpgroup import (
|
||||
OperandSource,
|
||||
)
|
||||
|
||||
|
||||
@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead")
|
||||
class SmemCapacity(Enum):
|
||||
SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# Dictionary to map compute capability to SMEM capacity
|
||||
SMEM_CAPACITY = {
|
||||
"sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value,
|
||||
}
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def sm90_get_smem_store_op(
|
||||
layout_d: LayoutEnum,
|
||||
@ -79,15 +98,6 @@ def sm90_get_smem_store_op(
|
||||
return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class SmemCapacity(Enum):
|
||||
SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
|
||||
|
||||
|
||||
# Dictionary to map compute capability to SMEM capacity
|
||||
SMEM_CAPACITY = {
|
||||
"sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value,
|
||||
}
|
||||
|
||||
def make_trivial_tiled_mma(
|
||||
a_dtype: Type[Numeric],
|
||||
b_dtype: Type[Numeric],
|
||||
|
||||
@ -14,7 +14,7 @@ from typing import Type, Union, overload
|
||||
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.arch import get_dyn_smem
|
||||
from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size
|
||||
|
||||
|
||||
class SmemAllocator:
|
||||
@ -60,6 +60,7 @@ class SmemAllocator:
|
||||
:return: Pointer to the start of the allocated memory block or struct instance
|
||||
:rtype: cute.Pointer
|
||||
:raises ValueError: If size is negative or alignment is less than 1
|
||||
:raises RuntimeError: If allocation would exceed available shared memory
|
||||
"""
|
||||
if isinstance(size_or_type, cute.struct):
|
||||
alignment = max(byte_alignment, size_or_type.__alignof__())
|
||||
@ -80,6 +81,14 @@ class SmemAllocator:
|
||||
byte_alignment - self._allocated_bytes % byte_alignment
|
||||
)
|
||||
self._allocated_bytes += num_bytes
|
||||
|
||||
# Check bounds against available dynamic shared memory
|
||||
cute.testing.assert_(
|
||||
self._allocated_bytes <= get_dyn_smem_size(),
|
||||
f"Allocation failed: shared memory allocation exceeds available memory set in kernel launch. "
|
||||
f"Allocated bytes: {self._allocated_bytes} bytes. "
|
||||
f"Please reduce the allocation or set a larger smem size in kernel launch.",
|
||||
)
|
||||
return ptr
|
||||
|
||||
def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1):
|
||||
|
||||
26
python/CuTeDSL/cutlass/utils/smem_capacity.py
Normal file
@ -0,0 +1,26 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
|
||||
SMEM_CAPACITY_MAP = {
|
||||
"sm_120": (100 - 1) * 1024,
|
||||
"sm_100": (228 - 1) * 1024,
|
||||
"sm_90": (228 - 1) * 1024,
|
||||
"sm_80": (164 - 1) * 1024,
|
||||
"sm_86": (100 - 1) * 1024,
|
||||
"sm_89": (100 - 1) * 1024,
|
||||
}
|
||||
|
||||
|
||||
def get_smem_capacity_in_bytes(compute_capability: str) -> int:
|
||||
if compute_capability not in SMEM_CAPACITY_MAP:
|
||||
raise ValueError(f"Unsupported compute capability: {compute_capability}")
|
||||
return SMEM_CAPACITY_MAP[compute_capability]
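# Example (values come straight from SMEM_CAPACITY_MAP above):
smem_bytes = get_smem_capacity_in_bytes("sm_100")  # (228 - 1) * 1024 = 232448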
|
||||
@ -159,7 +159,11 @@ class CutlassBaseDSL(BaseDSL):
|
||||
pipeline = super()._get_pipeline(pipeline)
|
||||
if pipeline == None:
|
||||
# The output format must be cubin because we launch the CUDA module at the Python level.
|
||||
return "builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3})"
|
||||
return (
|
||||
"builtin.module(cute-to-nvvm{cubin-format=bin "
|
||||
+ self.compile_options.to_str()
|
||||
+ "})"
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
@ -294,13 +298,8 @@ class CutlassBaseDSL(BaseDSL):
|
||||
self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
|
||||
)
|
||||
|
||||
def _get_globals(self):
|
||||
caller_globals = self.frame.f_globals
|
||||
caller_locals = self.frame.f_locals
|
||||
all_globals = globals().copy()
|
||||
all_globals.update(caller_globals)
|
||||
all_globals.update(caller_locals)
|
||||
return all_globals
|
||||
def _get_module_globals(self):
|
||||
return globals()
|
||||
|
||||
def _preprocess_launch_config_args(self, args, kwargs):
|
||||
"""Helper to preprocess args and kwargs for LaunchConfig"""
|
||||
@ -459,7 +458,10 @@ class KernelLauncher:
|
||||
|
||||
def _check_func_args(self, funcBody, *func_args, **func_kwargs):
|
||||
# Get function signature
|
||||
sig = inspect.signature(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
sig = funcBody.get_signature()
|
||||
else:
|
||||
sig = inspect.signature(funcBody)
|
||||
|
||||
# func_args and func_kwargs should match funcBody's signature,
|
||||
# no extra or missing arguments.
|
||||
@ -485,6 +487,7 @@ class KernelLauncher:
|
||||
|
||||
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
|
||||
self.dsl.kernel_symbols.append(name)
|
||||
self.dsl.frame = None
|
||||
return ret.launch_op_ret
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
@ -537,14 +540,18 @@ def pack_from_irvalue(
|
||||
mixed_values[idx] = obj
|
||||
elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
|
||||
mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
|
||||
elif isinstance(chunk, list) and chunk[0] is None:
|
||||
mixed_values[idx] = class_types[idx]
|
||||
else:
|
||||
try:
|
||||
if isinstance(chunk, list) and chunk[0] is None:
|
||||
mixed_values[idx] = class_types[idx]
|
||||
else:
|
||||
if len(chunk) == 1:
|
||||
try:
|
||||
mixed_values[idx] = t.as_numeric(chunk[0])
|
||||
except DSLRuntimeError as e:
|
||||
mixed_values[idx] = chunk[0]
|
||||
except ValueError:
|
||||
# Suppress the conversion error and try new_from_mlir_values below
|
||||
pass
|
||||
|
||||
if mixed_values[idx] is None:
|
||||
mixed_values[idx] = new_from_mlir_values(obj, chunk)
|
||||
|
||||
log().debug("------------------ ")
|
||||
for idx, packed in enumerate(mixed_values):
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.1.0.dev0
|
||||
nvidia-cutlass-dsl==4.1.0