v4.2.1 update. (#2666)
This commit is contained in:
@ -1435,8 +1435,12 @@ private:
|
|||||||
is_same_v<FastF32NoSmemWarpSpecialized2Sm, EpilogueScheduleType> ||
|
is_same_v<FastF32NoSmemWarpSpecialized2Sm, EpilogueScheduleType> ||
|
||||||
is_same_v<PtrArrayFastF32NoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
|
is_same_v<PtrArrayFastF32NoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||||
is_same_v<PtrArrayFastF32NoSmemWarpSpecialized2Sm, EpilogueScheduleType>;
|
is_same_v<PtrArrayFastF32NoSmemWarpSpecialized2Sm, EpilogueScheduleType>;
|
||||||
// Input transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support.
|
static constexpr bool IsBlockwiseSchedule = is_same_v<BlockwiseNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||||
static constexpr bool IsInputTransformSchedule = IsInterleavedComplex || IsFastF32Schedule;
|
is_same_v<BlockwiseNoSmemWarpSpecialized2Sm, EpilogueScheduleType> ||
|
||||||
|
is_same_v<PtrArrayBlockwiseNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||||
|
is_same_v<PtrArrayBlockwiseNoSmemWarpSpecialized2Sm, EpilogueScheduleType>;
|
||||||
|
// Transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support.
|
||||||
|
static constexpr bool IsTransformSchedule = IsInterleavedComplex || IsFastF32Schedule || IsBlockwiseSchedule;
|
||||||
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
|
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
|
||||||
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
|
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
|
||||||
|
|
||||||
@ -1470,7 +1474,7 @@ private:
|
|||||||
static_assert(is_tuple_v<EpilogueTileType>, "Shape or Tile");
|
static_assert(is_tuple_v<EpilogueTileType>, "Shape or Tile");
|
||||||
return EpilogueTileType{};
|
return EpilogueTileType{};
|
||||||
}
|
}
|
||||||
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp> || not IsInputTransformSchedule) {
|
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp> || not IsTransformSchedule) {
|
||||||
// Save register usage for sm103 blockscaled kernels and sm100 cpasync kernels
|
// Save register usage for sm103 blockscaled kernels and sm100 cpasync kernels
|
||||||
// to avoid register spilling.
|
// to avoid register spilling.
|
||||||
constexpr int EpiM = size<0>(CtaTileShape_MNK{});
|
constexpr int EpiM = size<0>(CtaTileShape_MNK{});
|
||||||
@ -1501,7 +1505,7 @@ private:
|
|||||||
DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||||
if constexpr (IsDefaultFusionOp<FusionOp>::value &&\
|
if constexpr (IsDefaultFusionOp<FusionOp>::value &&\
|
||||||
not is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> && \
|
not is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> && \
|
||||||
(IsInputTransformSchedule || \
|
(IsTransformSchedule || \
|
||||||
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized1Sm> || \
|
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized1Sm> || \
|
||||||
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized2Sm>)
|
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized2Sm>)
|
||||||
) {
|
) {
|
||||||
|
|||||||
@ -63,10 +63,14 @@ struct NoSmemWarpSpecialized1Sm {};
|
|||||||
struct NoSmemWarpSpecialized2Sm {};
|
struct NoSmemWarpSpecialized2Sm {};
|
||||||
struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
|
struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
|
||||||
struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
|
struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
|
||||||
|
struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
|
||||||
|
struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
|
||||||
struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
|
struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
|
||||||
struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
|
struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
|
||||||
struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {};
|
struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {};
|
||||||
struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {};
|
struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {};
|
||||||
|
struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {};
|
||||||
|
struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {};
|
||||||
// Blackwell TMA schedules
|
// Blackwell TMA schedules
|
||||||
struct TmaWarpSpecialized1Sm {};
|
struct TmaWarpSpecialized1Sm {};
|
||||||
struct TmaWarpSpecialized2Sm {};
|
struct TmaWarpSpecialized2Sm {};
|
||||||
|
|||||||
@ -55,3 +55,5 @@ LaunchConfig = _dsl.BaseDSL.LaunchConfig
|
|||||||
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
|
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
|
||||||
gpu = _dsl.cutlass_gpu
|
gpu = _dsl.cutlass_gpu
|
||||||
cuda = _dsl.cuda_helpers
|
cuda = _dsl.cuda_helpers
|
||||||
|
|
||||||
|
CACHE_FILE = "compiled_cache.db"
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import warnings
|
|||||||
import inspect
|
import inspect
|
||||||
from types import BuiltinFunctionType
|
from types import BuiltinFunctionType
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from inspect import getmembers
|
||||||
|
|
||||||
from .utils.logger import log
|
from .utils.logger import log
|
||||||
from .common import *
|
from .common import *
|
||||||
@ -579,3 +580,37 @@ def redirect_builtin_function(fcn):
|
|||||||
if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
|
if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
|
||||||
return executor._builtin_redirector(fcn)
|
return executor._builtin_redirector(fcn)
|
||||||
return fcn
|
return fcn
|
||||||
|
|
||||||
|
|
||||||
|
def copy_members(dest, src):
|
||||||
|
"""
|
||||||
|
Copies all non-callable, non-dunder members from src to dest if they exist in src.
|
||||||
|
Skips members that are callables or have names starting with double underscores.
|
||||||
|
"""
|
||||||
|
if id(dest) == id(src):
|
||||||
|
return
|
||||||
|
|
||||||
|
members = getmembers(dest)
|
||||||
|
for name, value in members:
|
||||||
|
if (
|
||||||
|
name.startswith("__")
|
||||||
|
or isinstance(value, Callable)
|
||||||
|
or not hasattr(src, name)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
setattr(dest, name, getattr(src, name))
|
||||||
|
|
||||||
|
|
||||||
|
def get_locals_or_none(locals, symbols):
|
||||||
|
"""
|
||||||
|
Given a locals() dictionary and a list of symbol names, return a list of their values
|
||||||
|
in the same order as the symbols list. If a symbol is not present in locals, None is returned
|
||||||
|
for that symbol.
|
||||||
|
"""
|
||||||
|
variables = []
|
||||||
|
for symbol in symbols:
|
||||||
|
if symbol in locals:
|
||||||
|
variables.append(locals[symbol])
|
||||||
|
else:
|
||||||
|
variables.append(None)
|
||||||
|
return variables
|
||||||
|
|||||||
@ -668,12 +668,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
ast.keyword(arg="prefetch_stages", value=prefetch_stages),
|
ast.keyword(arg="prefetch_stages", value=prefetch_stages),
|
||||||
ast.keyword(
|
ast.keyword(
|
||||||
arg="write_args",
|
arg="write_args",
|
||||||
value=ast.List(
|
value=self.generate_get_locals_or_none_call(write_args),
|
||||||
elts=[
|
|
||||||
ast.Name(id=arg, ctx=ast.Load()) for arg in write_args
|
|
||||||
],
|
|
||||||
ctx=ast.Load(),
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
ast.keyword(
|
ast.keyword(
|
||||||
arg="full_write_args_count",
|
arg="full_write_args_count",
|
||||||
@ -707,28 +702,6 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
node,
|
node,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_loop_call(self, func_name, iter_args):
|
|
||||||
"""
|
|
||||||
Assigns the returned value from the loop function directly (without a tuple unpacking).
|
|
||||||
"""
|
|
||||||
if len(iter_args) == 0:
|
|
||||||
return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load()))
|
|
||||||
elif len(iter_args) == 1:
|
|
||||||
return ast.Assign(
|
|
||||||
targets=[ast.Name(id=iter_args[0], ctx=ast.Store())],
|
|
||||||
value=ast.Name(id=func_name, ctx=ast.Load()),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ast.Assign(
|
|
||||||
targets=[
|
|
||||||
ast.Tuple(
|
|
||||||
elts=[ast.Name(id=var, ctx=ast.Store()) for var in iter_args],
|
|
||||||
ctx=ast.Store(),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
value=ast.Name(id=func_name, ctx=ast.Load()),
|
|
||||||
)
|
|
||||||
|
|
||||||
def visit_BoolOp(self, node):
|
def visit_BoolOp(self, node):
|
||||||
# Visit child nodes first
|
# Visit child nodes first
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
@ -1140,10 +1113,10 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
full_write_args_count,
|
full_write_args_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
assign = ast.copy_location(self.create_loop_call(func_name, write_args), node)
|
assign = self.create_cf_call(func_name, write_args, node)
|
||||||
|
|
||||||
# This should work fine as it modifies the AST structure
|
# This should work fine as it modifies the AST structure
|
||||||
exprs = exprs + [func_def, assign]
|
exprs = exprs + [func_def] + assign
|
||||||
|
|
||||||
if target_var_is_active_before_loop:
|
if target_var_is_active_before_loop:
|
||||||
# Create a new assignment to the target variable
|
# Create a new assignment to the target variable
|
||||||
@ -1429,11 +1402,9 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
func_def = self.create_while_function(
|
func_def = self.create_while_function(
|
||||||
func_name, node, write_args, full_write_args_count
|
func_name, node, write_args, full_write_args_count
|
||||||
)
|
)
|
||||||
assign = ast.copy_location(
|
assign = self.create_cf_call(func_name, write_args, node)
|
||||||
self.create_loop_call(func_name, write_args), node
|
|
||||||
)
|
|
||||||
|
|
||||||
return [func_def, assign]
|
return [func_def] + assign
|
||||||
|
|
||||||
def visit_Try(self, node):
|
def visit_Try(self, node):
|
||||||
with self.scope_manager:
|
with self.scope_manager:
|
||||||
@ -1447,17 +1418,27 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def create_if_call(self, func_name, yield_args):
|
def create_cf_call(self, func_name, yield_args, node):
|
||||||
"""Creates the assignment statement for the if function call"""
|
"""Creates the assignment statement for the if function call"""
|
||||||
if not yield_args:
|
if not yield_args:
|
||||||
return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load()))
|
return [
|
||||||
elif len(yield_args) == 1:
|
ast.copy_location(
|
||||||
return ast.Assign(
|
ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node
|
||||||
|
)
|
||||||
|
]
|
||||||
|
has_self = False
|
||||||
|
for i, arg in enumerate(yield_args):
|
||||||
|
if arg == "self":
|
||||||
|
has_self = True
|
||||||
|
yield_args[i] = "yield_self"
|
||||||
|
break
|
||||||
|
if len(yield_args) == 1:
|
||||||
|
assign = ast.Assign(
|
||||||
targets=[ast.Name(id=yield_args[0], ctx=ast.Store())],
|
targets=[ast.Name(id=yield_args[0], ctx=ast.Store())],
|
||||||
value=ast.Name(id=func_name, ctx=ast.Load()),
|
value=ast.Name(id=func_name, ctx=ast.Load()),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ast.Assign(
|
assign = ast.Assign(
|
||||||
targets=[
|
targets=[
|
||||||
ast.Tuple(
|
ast.Tuple(
|
||||||
elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args],
|
elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args],
|
||||||
@ -1467,6 +1448,23 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
value=ast.Name(id=func_name, ctx=ast.Load()),
|
value=ast.Name(id=func_name, ctx=ast.Load()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if has_self:
|
||||||
|
fix_self = ast.Expr(
|
||||||
|
value=ast.Call(
|
||||||
|
func=self._create_module_attribute(
|
||||||
|
"copy_members", lineno=node.lineno, col_offset=node.col_offset
|
||||||
|
),
|
||||||
|
args=[
|
||||||
|
ast.Name(id="self", ctx=ast.Load()),
|
||||||
|
ast.Name(id="yield_self", ctx=ast.Load()),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)]
|
||||||
|
else:
|
||||||
|
return [ast.copy_location(assign, node)]
|
||||||
|
|
||||||
def visit_IfExp(self, node):
|
def visit_IfExp(self, node):
|
||||||
"""
|
"""
|
||||||
Visits an inline if-else expression (ternary operator).
|
Visits an inline if-else expression (ternary operator).
|
||||||
@ -1567,9 +1565,24 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
func_def = self.create_if_function(
|
func_def = self.create_if_function(
|
||||||
func_name, node, yield_args, full_write_args_count
|
func_name, node, yield_args, full_write_args_count
|
||||||
)
|
)
|
||||||
assign = ast.copy_location(self.create_if_call(func_name, yield_args), node)
|
assign = self.create_cf_call(func_name, yield_args, node)
|
||||||
|
|
||||||
return [func_def, assign]
|
return [func_def] + assign
|
||||||
|
|
||||||
|
def generate_get_locals_or_none_call(self, write_args):
|
||||||
|
return ast.Call(
|
||||||
|
func=self._create_module_attribute("get_locals_or_none"),
|
||||||
|
args=[
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]
|
||||||
|
),
|
||||||
|
ast.List(
|
||||||
|
elts=[ast.Constant(value=arg) for arg in write_args],
|
||||||
|
ctx=ast.Load(),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
|
||||||
def create_if_function(self, func_name, node, write_args, full_write_args_count):
|
def create_if_function(self, func_name, node, write_args, full_write_args_count):
|
||||||
test_expr = self.visit(node.test)
|
test_expr = self.visit(node.test)
|
||||||
@ -1627,10 +1640,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
), # ast.Name(id="pred", ctx=ast.Load())
|
), # ast.Name(id="pred", ctx=ast.Load())
|
||||||
ast.keyword(
|
ast.keyword(
|
||||||
arg="write_args",
|
arg="write_args",
|
||||||
value=ast.List(
|
value=self.generate_get_locals_or_none_call(write_args),
|
||||||
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
|
|
||||||
ctx=ast.Load(),
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1813,10 +1823,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
|||||||
ast.keyword(arg="pred", value=test_expr),
|
ast.keyword(arg="pred", value=test_expr),
|
||||||
ast.keyword(
|
ast.keyword(
|
||||||
arg="write_args",
|
arg="write_args",
|
||||||
value=ast.List(
|
value=self.generate_get_locals_or_none_call(write_args),
|
||||||
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
|
|
||||||
ctx=ast.Load(),
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
decorator = ast.copy_location(
|
decorator = ast.copy_location(
|
||||||
|
|||||||
@ -255,7 +255,13 @@ def initialize_cuda_context(device_id: int = 0, flags: int = 0):
|
|||||||
_log().info(f"{cuDevice} <-- cuDeviceGet")
|
_log().info(f"{cuDevice} <-- cuDeviceGet")
|
||||||
# Create context
|
# Create context
|
||||||
_log().info(f"cuCtxCreate {0} {cuDevice}")
|
_log().info(f"cuCtxCreate {0} {cuDevice}")
|
||||||
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
|
if cuda.CUDA_VERSION >= 13000:
|
||||||
|
# Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2
|
||||||
|
# and v3 API has been removed from CTK 13.
|
||||||
|
# See https://github.com/NVIDIA/cuda-python/pull/792
|
||||||
|
context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice))
|
||||||
|
else:
|
||||||
|
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
|
||||||
_log().info(f"{context} <-- cuCtxCreate")
|
_log().info(f"{context} <-- cuCtxCreate")
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|||||||
@ -47,7 +47,8 @@ def setup_log(
|
|||||||
if log_to_console or log_to_file:
|
if log_to_console or log_to_file:
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
else:
|
else:
|
||||||
logger.setLevel(logging.NOTSET)
|
# Makes sure logging is OFF
|
||||||
|
logger.setLevel(logging.CRITICAL + 1)
|
||||||
|
|
||||||
# Clear existing handlers to prevent duplicate logs
|
# Clear existing handlers to prevent duplicate logs
|
||||||
if logger.hasHandlers():
|
if logger.hasHandlers():
|
||||||
|
|||||||
@ -31,6 +31,8 @@ from ..base_dsl.ast_helpers import (
|
|||||||
range_perf_warning,
|
range_perf_warning,
|
||||||
cf_symbol_check,
|
cf_symbol_check,
|
||||||
redirect_builtin_function,
|
redirect_builtin_function,
|
||||||
|
copy_members,
|
||||||
|
get_locals_or_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..base_dsl import *
|
from ..base_dsl import *
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# Use `pip install -r requirements.txt` with the present file to install a
|
# Use `pip install -r requirements.txt` with the present file to install a
|
||||||
# wheel consistent with the present state of the github repository
|
# wheel consistent with the present state of the github repository
|
||||||
nvidia-cutlass-dsl==4.2.0
|
nvidia-cutlass-dsl==4.2.1
|
||||||
|
|||||||
@ -7467,8 +7467,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
|
|||||||
|
|
||||||
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
|
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
|
||||||
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
|
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
|
||||||
|
epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped)
|
||||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||||
[[kernel_schedule, epi_schedule]],
|
[[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]],
|
||||||
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||||
|
|
||||||
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):
|
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):
|
||||||
|
|||||||
@ -811,10 +811,14 @@ class EpilogueScheduleType(enum.Enum):
|
|||||||
NoSmemWarpSpecialized2Sm = enum_auto()
|
NoSmemWarpSpecialized2Sm = enum_auto()
|
||||||
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||||
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||||
|
BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
|
||||||
|
BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
|
||||||
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
|
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
|
||||||
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
|
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
|
||||||
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||||
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||||
|
PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
|
||||||
|
PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
|
||||||
TmaWarpSpecialized = enum_auto()
|
TmaWarpSpecialized = enum_auto()
|
||||||
TmaWarpSpecializedCooperative = enum_auto()
|
TmaWarpSpecializedCooperative = enum_auto()
|
||||||
TmaWarpSpecialized1Sm = enum_auto()
|
TmaWarpSpecialized1Sm = enum_auto()
|
||||||
@ -834,10 +838,14 @@ EpilogueScheduleTag = {
|
|||||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
|
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
|
||||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
|
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
|
||||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
|
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
|
||||||
|
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm',
|
||||||
|
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm',
|
||||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
|
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
|
||||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
|
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
|
||||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
|
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
|
||||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
|
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
|
||||||
|
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm',
|
||||||
|
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm',
|
||||||
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
|
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
|
||||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
|
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
|
||||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
|
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
|
||||||
@ -858,10 +866,14 @@ EpilogueScheduleSuffixes = {
|
|||||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
||||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
||||||
|
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||||
|
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
||||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
||||||
|
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||||
|
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||||
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
|
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
|
||||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
||||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
|
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
|
||||||
@ -926,6 +938,8 @@ def to_grouped_schedule(schedule, grouped):
|
|||||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
||||||
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
|
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
|
||||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
|
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
|
||||||
|
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm,
|
||||||
|
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm,
|
||||||
# SM103
|
# SM103
|
||||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
|
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
|
||||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,
|
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,
|
||||||
|
|||||||
@ -69,6 +69,7 @@ template<cute::UMMA::Major SFAMajor,
|
|||||||
int ScaleGranularityN,
|
int ScaleGranularityN,
|
||||||
int ScaleGranularityK,
|
int ScaleGranularityK,
|
||||||
bool Is2SM,
|
bool Is2SM,
|
||||||
|
bool NoSmemEpilogue,
|
||||||
class LayoutA,
|
class LayoutA,
|
||||||
class LayoutB,
|
class LayoutB,
|
||||||
class LayoutCD,
|
class LayoutCD,
|
||||||
@ -77,8 +78,10 @@ template<cute::UMMA::Major SFAMajor,
|
|||||||
bool groupwise_test(
|
bool groupwise_test(
|
||||||
Int<ScaleGranularityM>, Int<ScaleGranularityN>, Int<ScaleGranularityK>, C<Is2SM>,
|
Int<ScaleGranularityM>, Int<ScaleGranularityN>, Int<ScaleGranularityK>, C<Is2SM>,
|
||||||
LayoutA, LayoutB, LayoutCD,
|
LayoutA, LayoutB, LayoutCD,
|
||||||
MmaTileShape, ClusterShape) {
|
MmaTileShape, ClusterShape,
|
||||||
|
C<NoSmemEpilogue>) {
|
||||||
|
using Epilogue1SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>;
|
||||||
|
using Epilogue2SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized2Sm>;
|
||||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, SFAMajor, SFBMajor>;
|
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, SFAMajor, SFBMajor>;
|
||||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||||
@ -90,7 +93,7 @@ bool groupwise_test(
|
|||||||
float, float,
|
float, float,
|
||||||
cutlass::float_e4m3_t, LayoutCD, 16,
|
cutlass::float_e4m3_t, LayoutCD, 16,
|
||||||
cutlass::float_e4m3_t, LayoutCD, 16,
|
cutlass::float_e4m3_t, LayoutCD, 16,
|
||||||
conditional_t<Is2SM, cutlass::epilogue::TmaWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>
|
conditional_t<Is2SM, Epilogue2SM, Epilogue1SM>
|
||||||
>::CollectiveOp;
|
>::CollectiveOp;
|
||||||
|
|
||||||
using CollectiveMainloop =
|
using CollectiveMainloop =
|
||||||
@ -259,11 +262,26 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
|
|||||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||||
cutlass::layout::RowMajor{},
|
cutlass::layout::RowMajor{},
|
||||||
Shape<_128,_128,_128>{},
|
Shape<_128,_128,_128>{},
|
||||||
Shape<_1,_1,_1>{});
|
Shape<_1,_1,_1>{},
|
||||||
|
false_type{});
|
||||||
|
|
||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128x128x128_1x1x1_2x2x32_scale_direct_store) {
|
||||||
|
|
||||||
|
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::K>(
|
||||||
|
Int<2>{}, Int<2>{}, Int<32>{}, false_type{},
|
||||||
|
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||||
|
cutlass::layout::RowMajor{},
|
||||||
|
Shape<_128,_128,_128>{},
|
||||||
|
Shape<_1,_1,_1>{},
|
||||||
|
true_type{});
|
||||||
|
|
||||||
|
EXPECT_TRUE(passed);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256x128x128_2x1x1_64x4x32_scale) {
|
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256x128x128_2x1x1_64x4x32_scale) {
|
||||||
|
|
||||||
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::MN>(
|
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::MN>(
|
||||||
@ -271,7 +289,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
|
|||||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||||
cutlass::layout::RowMajor{},
|
cutlass::layout::RowMajor{},
|
||||||
Shape<_256,_128,_128>{},
|
Shape<_256,_128,_128>{},
|
||||||
Shape<_2,_1,_1>{});
|
Shape<_2,_1,_1>{},
|
||||||
|
false_type{});
|
||||||
|
|
||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
|
|
||||||
@ -284,7 +303,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
|
|||||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||||
cutlass::layout::RowMajor{},
|
cutlass::layout::RowMajor{},
|
||||||
Shape<_128,_128,_128>{},
|
Shape<_128,_128,_128>{},
|
||||||
Shape<_1,_1,_1>{});
|
Shape<_1,_1,_1>{},
|
||||||
|
false_type{});
|
||||||
|
|
||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
|
|
||||||
@ -297,7 +317,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
|
|||||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||||
cutlass::layout::RowMajor{},
|
cutlass::layout::RowMajor{},
|
||||||
Shape<_256,_128,_128>{},
|
Shape<_256,_128,_128>{},
|
||||||
Shape<_2,_1,_1>{});
|
Shape<_2,_1,_1>{},
|
||||||
|
false_type{});
|
||||||
|
|
||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
|
|
||||||
@ -311,7 +332,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
|
|||||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||||
cutlass::layout::RowMajor{},
|
cutlass::layout::RowMajor{},
|
||||||
Shape<_256,_128,_128>{},
|
Shape<_256,_128,_128>{},
|
||||||
Shape<_2,_1,_1>{});
|
Shape<_2,_1,_1>{},
|
||||||
|
false_type{});
|
||||||
|
|
||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user