From ee914c3cec1a619881e481b96da536fff21fe9cd Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Wed, 24 Sep 2025 02:25:14 +0800 Subject: [PATCH] v4.2.1 update. (#2667) --- .../collective/builders/sm100_builder.inl | 12 +- include/cutlass/epilogue/dispatch_policy.hpp | 4 + python/CuTeDSL/base_dsl/ast_helpers.py | 35 ++++++ python/CuTeDSL/base_dsl/ast_preprocessor.py | 105 ++++++++++-------- python/CuTeDSL/base_dsl/runtime/cuda.py | 8 +- python/CuTeDSL/base_dsl/utils/logger.py | 3 +- python/CuTeDSL/cutlass/__init__.py | 2 + python/CuTeDSL/cutlass_dsl/__init__.py | 2 + python/CuTeDSL/requirements.txt | 2 +- python/cutlass_library/generator.py | 3 +- python/cutlass_library/library.py | 14 +++ ...0_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu | 38 +++++-- 12 files changed, 163 insertions(+), 65 deletions(-) diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index 8634134b..7c14ac33 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -1435,8 +1435,12 @@ private: is_same_v || is_same_v || is_same_v; - // Input transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support. - static constexpr bool IsInputTransformSchedule = IsInterleavedComplex || IsFastF32Schedule; + static constexpr bool IsBlockwiseSchedule = is_same_v || + is_same_v || + is_same_v || + is_same_v; + // Transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support. + static constexpr bool IsTransformSchedule = IsInterleavedComplex || IsFastF32Schedule || IsBlockwiseSchedule; static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); @@ -1470,7 +1474,7 @@ private: static_assert(is_tuple_v, "Shape or Tile"); return EpilogueTileType{}; } - else if constexpr (is_same_v || not IsInputTransformSchedule) { + else if constexpr (is_same_v || not IsTransformSchedule) { // Save register usage for sm103 blockscaled kernels and sm100 cpasync kernels // to avoid register spilling. constexpr int EpiM = size<0>(CtaTileShape_MNK{}); @@ -1501,7 +1505,7 @@ private: DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; if constexpr (IsDefaultFusionOp::value &&\ not is_same_v && \ - (IsInputTransformSchedule || \ + (IsTransformSchedule || \ is_same_v || \ is_same_v) ) { diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index c9788a42..ca91ac19 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -63,10 +63,14 @@ struct NoSmemWarpSpecialized1Sm {}; struct NoSmemWarpSpecialized2Sm {}; struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; +struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; // Blackwell TMA schedules struct TmaWarpSpecialized1Sm {}; struct TmaWarpSpecialized2Sm {}; diff --git a/python/CuTeDSL/base_dsl/ast_helpers.py b/python/CuTeDSL/base_dsl/ast_helpers.py index 7b0832b8..7b11474c 100644 --- a/python/CuTeDSL/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/base_dsl/ast_helpers.py @@ -20,6 +20,7 @@ import warnings import inspect from types import BuiltinFunctionType from functools import lru_cache +from inspect import getmembers from .utils.logger import log from .common import * @@ -579,3 +580,37 @@ def redirect_builtin_function(fcn): if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector: return executor._builtin_redirector(fcn) return fcn + + +def copy_members(dest, src): + """ + Copies all non-callable, non-dunder members from src to dest if they exist in src. + Skips members that are callables or have names starting with double underscores. + """ + if id(dest) == id(src): + return + + members = getmembers(dest) + for name, value in members: + if ( + name.startswith("__") + or isinstance(value, Callable) + or not hasattr(src, name) + ): + continue + setattr(dest, name, getattr(src, name)) + + +def get_locals_or_none(locals, symbols): + """ + Given a locals() dictionary and a list of symbol names, return a list of their values + in the same order as the symbols list. If a symbol is not present in locals, None is returned + for that symbol. + """ + variables = [] + for symbol in symbols: + if symbol in locals: + variables.append(locals[symbol]) + else: + variables.append(None) + return variables diff --git a/python/CuTeDSL/base_dsl/ast_preprocessor.py b/python/CuTeDSL/base_dsl/ast_preprocessor.py index b9991a75..11f2d1ae 100644 --- a/python/CuTeDSL/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/base_dsl/ast_preprocessor.py @@ -668,12 +668,7 @@ class DSLPreprocessor(ast.NodeTransformer): ast.keyword(arg="prefetch_stages", value=prefetch_stages), ast.keyword( arg="write_args", - value=ast.List( - elts=[ - ast.Name(id=arg, ctx=ast.Load()) for arg in write_args - ], - ctx=ast.Load(), - ), + value=self.generate_get_locals_or_none_call(write_args), ), ast.keyword( arg="full_write_args_count", @@ -707,28 +702,6 @@ class DSLPreprocessor(ast.NodeTransformer): node, ) - def create_loop_call(self, func_name, iter_args): - """ - Assigns the returned value from the loop function directly (without a tuple unpacking). - """ - if len(iter_args) == 0: - return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())) - elif len(iter_args) == 1: - return ast.Assign( - targets=[ast.Name(id=iter_args[0], ctx=ast.Store())], - value=ast.Name(id=func_name, ctx=ast.Load()), - ) - else: - return ast.Assign( - targets=[ - ast.Tuple( - elts=[ast.Name(id=var, ctx=ast.Store()) for var in iter_args], - ctx=ast.Store(), - ) - ], - value=ast.Name(id=func_name, ctx=ast.Load()), - ) - def visit_BoolOp(self, node): # Visit child nodes first self.generic_visit(node) @@ -1140,10 +1113,10 @@ class DSLPreprocessor(ast.NodeTransformer): full_write_args_count, ) - assign = ast.copy_location(self.create_loop_call(func_name, write_args), node) + assign = self.create_cf_call(func_name, write_args, node) # This should work fine as it modifies the AST structure - exprs = exprs + [func_def, assign] + exprs = exprs + [func_def] + assign if target_var_is_active_before_loop: # Create a new assignment to the target variable @@ -1429,11 +1402,9 @@ class DSLPreprocessor(ast.NodeTransformer): func_def = self.create_while_function( func_name, node, write_args, full_write_args_count ) - assign = ast.copy_location( - self.create_loop_call(func_name, write_args), node - ) + assign = self.create_cf_call(func_name, write_args, node) - return [func_def, assign] + return [func_def] + assign def visit_Try(self, node): with self.scope_manager: @@ -1447,17 +1418,27 @@ class DSLPreprocessor(ast.NodeTransformer): self.generic_visit(node) return node - def create_if_call(self, func_name, yield_args): + def create_cf_call(self, func_name, yield_args, node): """Creates the assignment statement for the if function call""" if not yield_args: - return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())) - elif len(yield_args) == 1: - return ast.Assign( + return [ + ast.copy_location( + ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node + ) + ] + has_self = False + for i, arg in enumerate(yield_args): + if arg == "self": + has_self = True + yield_args[i] = "yield_self" + break + if len(yield_args) == 1: + assign = ast.Assign( targets=[ast.Name(id=yield_args[0], ctx=ast.Store())], value=ast.Name(id=func_name, ctx=ast.Load()), ) else: - return ast.Assign( + assign = ast.Assign( targets=[ ast.Tuple( elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args], @@ -1467,6 +1448,23 @@ class DSLPreprocessor(ast.NodeTransformer): value=ast.Name(id=func_name, ctx=ast.Load()), ) + if has_self: + fix_self = ast.Expr( + value=ast.Call( + func=self._create_module_attribute( + "copy_members", lineno=node.lineno, col_offset=node.col_offset + ), + args=[ + ast.Name(id="self", ctx=ast.Load()), + ast.Name(id="yield_self", ctx=ast.Load()), + ], + keywords=[], + ) + ) + return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)] + else: + return [ast.copy_location(assign, node)] + def visit_IfExp(self, node): """ Visits an inline if-else expression (ternary operator). @@ -1567,9 +1565,24 @@ class DSLPreprocessor(ast.NodeTransformer): func_def = self.create_if_function( func_name, node, yield_args, full_write_args_count ) - assign = ast.copy_location(self.create_if_call(func_name, yield_args), node) + assign = self.create_cf_call(func_name, yield_args, node) - return [func_def, assign] + return [func_def] + assign + + def generate_get_locals_or_none_call(self, write_args): + return ast.Call( + func=self._create_module_attribute("get_locals_or_none"), + args=[ + ast.Call( + func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[] + ), + ast.List( + elts=[ast.Constant(value=arg) for arg in write_args], + ctx=ast.Load(), + ), + ], + keywords=[], + ) def create_if_function(self, func_name, node, write_args, full_write_args_count): test_expr = self.visit(node.test) @@ -1627,10 +1640,7 @@ class DSLPreprocessor(ast.NodeTransformer): ), # ast.Name(id="pred", ctx=ast.Load()) ast.keyword( arg="write_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], - ctx=ast.Load(), - ), + value=self.generate_get_locals_or_none_call(write_args), ), ] @@ -1813,10 +1823,7 @@ class DSLPreprocessor(ast.NodeTransformer): ast.keyword(arg="pred", value=test_expr), ast.keyword( arg="write_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], - ctx=ast.Load(), - ), + value=self.generate_get_locals_or_none_call(write_args), ), ] decorator = ast.copy_location( diff --git a/python/CuTeDSL/base_dsl/runtime/cuda.py b/python/CuTeDSL/base_dsl/runtime/cuda.py index c2ad2203..97ae778c 100644 --- a/python/CuTeDSL/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/base_dsl/runtime/cuda.py @@ -255,7 +255,13 @@ def initialize_cuda_context(device_id: int = 0, flags: int = 0): _log().info(f"{cuDevice} <-- cuDeviceGet") # Create context _log().info(f"cuCtxCreate {0} {cuDevice}") - context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice)) + if cuda.CUDA_VERSION >= 13000: + # Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2 + # and v3 API has been removed from CTK 13. + # See https://github.com/NVIDIA/cuda-python/pull/792 + context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice)) + else: + context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice)) _log().info(f"{context} <-- cuCtxCreate") return context diff --git a/python/CuTeDSL/base_dsl/utils/logger.py b/python/CuTeDSL/base_dsl/utils/logger.py index b239f346..d4e4b4ed 100644 --- a/python/CuTeDSL/base_dsl/utils/logger.py +++ b/python/CuTeDSL/base_dsl/utils/logger.py @@ -47,7 +47,8 @@ def setup_log( if log_to_console or log_to_file: logger.setLevel(log_level) else: - logger.setLevel(logging.NOTSET) + # Makes sure logging is OFF + logger.setLevel(logging.CRITICAL + 1) # Clear existing handlers to prevent duplicate logs if logger.hasHandlers(): diff --git a/python/CuTeDSL/cutlass/__init__.py b/python/CuTeDSL/cutlass/__init__.py index d0e7c93b..f2c7ed26 100644 --- a/python/CuTeDSL/cutlass/__init__.py +++ b/python/CuTeDSL/cutlass/__init__.py @@ -55,3 +55,5 @@ LaunchConfig = _dsl.BaseDSL.LaunchConfig register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter gpu = _dsl.cutlass_gpu cuda = _dsl.cuda_helpers + +CACHE_FILE = "compiled_cache.db" diff --git a/python/CuTeDSL/cutlass_dsl/__init__.py b/python/CuTeDSL/cutlass_dsl/__init__.py index 5492fb51..06ea3f6f 100644 --- a/python/CuTeDSL/cutlass_dsl/__init__.py +++ b/python/CuTeDSL/cutlass_dsl/__init__.py @@ -31,6 +31,8 @@ from ..base_dsl.ast_helpers import ( range_perf_warning, cf_symbol_check, redirect_builtin_function, + copy_members, + get_locals_or_none, ) from ..base_dsl import * diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index 3d1d5b00..f588ea75 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.2.0 +nvidia-cutlass-dsl==4.2.1 diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index c10fe315..063e8fb1 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -7467,8 +7467,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped) epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, epi_schedule]], + [[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 19875a43..56d22dc4 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -811,10 +811,14 @@ class EpilogueScheduleType(enum.Enum): NoSmemWarpSpecialized2Sm = enum_auto() FastF32NoSmemWarpSpecialized1Sm = enum_auto() FastF32NoSmemWarpSpecialized2Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized2Sm = enum_auto() PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto() PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecialized1Sm = enum_auto() @@ -834,10 +838,14 @@ EpilogueScheduleTag = { EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm', EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm', EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm', EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', @@ -858,10 +866,14 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.TmaWarpSpecialized1Sm: '', @@ -926,6 +938,8 @@ def to_grouped_schedule(schedule, grouped): EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm, # SM103 KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu index 0e4c4cda..85abdb6d 100644 --- a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu @@ -69,6 +69,7 @@ template, Int, Int, C, LayoutA, LayoutB, LayoutCD, - MmaTileShape, ClusterShape) { - + MmaTileShape, ClusterShape, + C) { + using Epilogue1SM = conditional_t; + using Epilogue2SM = conditional_t; using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand @@ -90,7 +93,7 @@ bool groupwise_test( float, float, cutlass::float_e4m3_t, LayoutCD, 16, cutlass::float_e4m3_t, LayoutCD, 16, - conditional_t + conditional_t >::CollectiveOp; using CollectiveMainloop = @@ -259,11 +262,26 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128 cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, cutlass::layout::RowMajor{}, Shape<_128,_128,_128>{}, - Shape<_1,_1,_1>{}); + Shape<_1,_1,_1>{}, + false_type{}); EXPECT_TRUE(passed); } +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128x128x128_1x1x1_2x2x32_scale_direct_store) { + + bool passed = groupwise_test( + Int<2>{}, Int<2>{}, Int<32>{}, false_type{}, + cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, + cutlass::layout::RowMajor{}, + Shape<_128,_128,_128>{}, + Shape<_1,_1,_1>{}, + true_type{}); + + EXPECT_TRUE(passed); +} + + TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256x128x128_2x1x1_64x4x32_scale) { bool passed = groupwise_test( @@ -271,7 +289,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256 cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, cutlass::layout::RowMajor{}, Shape<_256,_128,_128>{}, - Shape<_2,_1,_1>{}); + Shape<_2,_1,_1>{}, + false_type{}); EXPECT_TRUE(passed); @@ -284,7 +303,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128 cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, cutlass::layout::RowMajor{}, Shape<_128,_128,_128>{}, - Shape<_1,_1,_1>{}); + Shape<_1,_1,_1>{}, + false_type{}); EXPECT_TRUE(passed); @@ -297,7 +317,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256 cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, cutlass::layout::RowMajor{}, Shape<_256,_128,_128>{}, - Shape<_2,_1,_1>{}); + Shape<_2,_1,_1>{}, + false_type{}); EXPECT_TRUE(passed); @@ -311,7 +332,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256 cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, cutlass::layout::RowMajor{}, Shape<_256,_128,_128>{}, - Shape<_2,_1,_1>{}); + Shape<_2,_1,_1>{}, + false_type{}); EXPECT_TRUE(passed);