v4.2.1 update. (#2666)

This commit is contained in:
Junkai-Wu
2025-09-24 01:25:43 +08:00
committed by GitHub
parent 2b8dff1f90
commit 7a6d4ee099
12 changed files with 163 additions and 65 deletions

View File

@ -55,3 +55,5 @@ LaunchConfig = _dsl.BaseDSL.LaunchConfig
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
gpu = _dsl.cutlass_gpu
cuda = _dsl.cuda_helpers
CACHE_FILE = "compiled_cache.db"

View File

@ -20,6 +20,7 @@ import warnings
import inspect
from types import BuiltinFunctionType
from functools import lru_cache
from inspect import getmembers
from .utils.logger import log
from .common import *
@ -579,3 +580,37 @@ def redirect_builtin_function(fcn):
if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
return executor._builtin_redirector(fcn)
return fcn
def copy_members(dest, src):
    """
    Overwrite data attributes of ``dest`` with the same-named attributes of ``src``.

    Candidate names are taken from ``dest`` (via ``inspect.getmembers``); a name
    is copied only when ``src`` also exposes it. Members whose name starts with
    a double underscore, and members whose current value on ``dest`` is
    callable, are left untouched.

    :param dest: Object that is updated in place.
    :param src: Object the replacement attribute values are read from.
    :return: None
    """
    # Nothing to copy when both arguments are the very same object.
    if dest is src:
        return

    # Sentinel so that an attribute stored as None on src is still copied.
    missing = object()
    for name, value in getmembers(dest):
        if name.startswith("__") or callable(value):
            continue
        # Single lookup on src instead of hasattr() followed by getattr().
        replacement = getattr(src, name, missing)
        if replacement is not missing:
            setattr(dest, name, replacement)
def get_locals_or_none(locals, symbols):
    """
    Look up each name in ``symbols`` in a ``locals()``-style dictionary.

    :param locals: Dictionary mapping variable names to values, typically the
        result of calling ``locals()``. The parameter name shadows the builtin
        but is kept so existing callers continue to work.
    :param symbols: Iterable of symbol names to look up.
    :return: List of values in the same order as ``symbols``; ``None`` stands
        in for every symbol that is not present in ``locals``.
    """
    # dict.get returns None for a missing key, so one lookup per symbol suffices.
    return [locals.get(symbol) for symbol in symbols]

View File

@ -668,12 +668,7 @@ class DSLPreprocessor(ast.NodeTransformer):
ast.keyword(arg="prefetch_stages", value=prefetch_stages),
ast.keyword(
arg="write_args",
value=ast.List(
elts=[
ast.Name(id=arg, ctx=ast.Load()) for arg in write_args
],
ctx=ast.Load(),
),
value=self.generate_get_locals_or_none_call(write_args),
),
ast.keyword(
arg="full_write_args_count",
@ -707,28 +702,6 @@ class DSLPreprocessor(ast.NodeTransformer):
node,
)
def create_loop_call(self, func_name, iter_args):
    """
    Build the statement that binds the loop function's result to ``iter_args``.

    The right-hand side is a plain reference to ``func_name`` (an ``ast.Name``
    in load context, not an ``ast.Call``).

    :param func_name: Identifier of the generated loop function.
    :param iter_args: Sequence of variable names to bind the result to.
    :return: ``ast.Expr`` wrapping the reference when ``iter_args`` is empty;
        otherwise an ``ast.Assign`` whose target is the single name itself, or
        an ``ast.Tuple`` of names when there is more than one.
    """
    loop_ref = ast.Name(id=func_name, ctx=ast.Load())
    if not iter_args:
        return ast.Expr(value=loop_ref)

    names = [ast.Name(id=var, ctx=ast.Store()) for var in iter_args]
    # A single name is assigned directly rather than through a 1-tuple unpack.
    target = names[0] if len(names) == 1 else ast.Tuple(elts=names, ctx=ast.Store())
    return ast.Assign(targets=[target], value=loop_ref)
def visit_BoolOp(self, node):
# Visit child nodes first
self.generic_visit(node)
@ -1140,10 +1113,10 @@ class DSLPreprocessor(ast.NodeTransformer):
full_write_args_count,
)
assign = ast.copy_location(self.create_loop_call(func_name, write_args), node)
assign = self.create_cf_call(func_name, write_args, node)
# This should work fine as it modifies the AST structure
exprs = exprs + [func_def, assign]
exprs = exprs + [func_def] + assign
if target_var_is_active_before_loop:
# Create a new assignment to the target variable
@ -1429,11 +1402,9 @@ class DSLPreprocessor(ast.NodeTransformer):
func_def = self.create_while_function(
func_name, node, write_args, full_write_args_count
)
assign = ast.copy_location(
self.create_loop_call(func_name, write_args), node
)
assign = self.create_cf_call(func_name, write_args, node)
return [func_def, assign]
return [func_def] + assign
def visit_Try(self, node):
with self.scope_manager:
@ -1447,17 +1418,27 @@ class DSLPreprocessor(ast.NodeTransformer):
self.generic_visit(node)
return node
def create_if_call(self, func_name, yield_args):
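def create_cf_call(self, func_name, yield_args, node):
"""
Creates the list of statements that binds the result of a generated
control-flow function (used for the if, for and while rewrites) to
``yield_args``. Every returned statement takes its location from ``node``.
"""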
if not yield_args:
return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load()))
elif len(yield_args) == 1:
return ast.Assign(
return [
ast.copy_location(
ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node
)
]
has_self = False
for i, arg in enumerate(yield_args):
if arg == "self":
has_self = True
yield_args[i] = "yield_self"
break
if len(yield_args) == 1:
assign = ast.Assign(
targets=[ast.Name(id=yield_args[0], ctx=ast.Store())],
value=ast.Name(id=func_name, ctx=ast.Load()),
)
else:
return ast.Assign(
assign = ast.Assign(
targets=[
ast.Tuple(
elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args],
@ -1467,6 +1448,23 @@ class DSLPreprocessor(ast.NodeTransformer):
value=ast.Name(id=func_name, ctx=ast.Load()),
)
if has_self:
fix_self = ast.Expr(
value=ast.Call(
func=self._create_module_attribute(
"copy_members", lineno=node.lineno, col_offset=node.col_offset
),
args=[
ast.Name(id="self", ctx=ast.Load()),
ast.Name(id="yield_self", ctx=ast.Load()),
],
keywords=[],
)
)
return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)]
else:
return [ast.copy_location(assign, node)]
def visit_IfExp(self, node):
"""
Visits an inline if-else expression (ternary operator).
@ -1567,9 +1565,24 @@ class DSLPreprocessor(ast.NodeTransformer):
func_def = self.create_if_function(
func_name, node, yield_args, full_write_args_count
)
assign = ast.copy_location(self.create_if_call(func_name, yield_args), node)
assign = self.create_cf_call(func_name, yield_args, node)
return [func_def, assign]
return [func_def] + assign
def generate_get_locals_or_none_call(self, write_args):
    """
    Build the AST for ``<helper>(locals(), [<write_args as string constants>])``.

    ``<helper>`` is whatever ``self._create_module_attribute`` returns for
    ``"get_locals_or_none"`` -- presumably a reference to that helper in the
    DSL module (TODO confirm against ``_create_module_attribute``).

    :param write_args: Iterable of variable names; each becomes an
        ``ast.Constant`` element of the second call argument.
    :return: The ``ast.Call`` node for the helper invocation.
    """
    # Same construction order as before: callee first, then the two arguments.
    helper_ref = self._create_module_attribute("get_locals_or_none")
    locals_call = ast.Call(
        func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]
    )
    symbol_names = ast.List(
        elts=[ast.Constant(value=name) for name in write_args],
        ctx=ast.Load(),
    )
    return ast.Call(func=helper_ref, args=[locals_call, symbol_names], keywords=[])
def create_if_function(self, func_name, node, write_args, full_write_args_count):
test_expr = self.visit(node.test)
@ -1627,10 +1640,7 @@ class DSLPreprocessor(ast.NodeTransformer):
), # ast.Name(id="pred", ctx=ast.Load())
ast.keyword(
arg="write_args",
value=ast.List(
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
ctx=ast.Load(),
),
value=self.generate_get_locals_or_none_call(write_args),
),
]
@ -1813,10 +1823,7 @@ class DSLPreprocessor(ast.NodeTransformer):
ast.keyword(arg="pred", value=test_expr),
ast.keyword(
arg="write_args",
value=ast.List(
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
ctx=ast.Load(),
),
value=self.generate_get_locals_or_none_call(write_args),
),
]
decorator = ast.copy_location(

View File

@ -255,7 +255,13 @@ def initialize_cuda_context(device_id: int = 0, flags: int = 0):
_log().info(f"{cuDevice} <-- cuDeviceGet")
# Create context
_log().info(f"cuCtxCreate {0} {cuDevice}")
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
if cuda.CUDA_VERSION >= 13000:
# Use the cuCtxCreate_v4 API with an explicit CUctxCreateParams of None, since
# the v2 and v3 APIs have been removed from CTK 13.
# See https://github.com/NVIDIA/cuda-python/pull/792
context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice))
else:
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
_log().info(f"{context} <-- cuCtxCreate")
return context

View File

@ -47,7 +47,8 @@ def setup_log(
if log_to_console or log_to_file:
logger.setLevel(log_level)
else:
logger.setLevel(logging.NOTSET)
# Makes sure logging is OFF
logger.setLevel(logging.CRITICAL + 1)
# Clear existing handlers to prevent duplicate logs
if logger.hasHandlers():

View File

@ -31,6 +31,8 @@ from ..base_dsl.ast_helpers import (
range_perf_warning,
cf_symbol_check,
redirect_builtin_function,
copy_members,
get_locals_or_none,
)
from ..base_dsl import *

View File

@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.2.0
nvidia-cutlass-dsl==4.2.1

View File

@ -7467,8 +7467,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]],
[[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]],
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):

View File

@ -811,10 +811,14 @@ class EpilogueScheduleType(enum.Enum):
NoSmemWarpSpecialized2Sm = enum_auto()
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
@ -834,10 +838,14 @@ EpilogueScheduleTag = {
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
@ -858,10 +866,14 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
@ -926,6 +938,8 @@ def to_grouped_schedule(schedule, grouped):
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm,
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm,
# SM103
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,