Updates for CUTLASS 3.5.0 (#1468)

This commit is contained in:
Vijay Thakkar
2024-04-11 21:33:40 -04:00
committed by GitHub
parent a40e08e9d5
commit 7d49e6c7e2
171 changed files with 7526 additions and 1888 deletions

View File

@ -1654,7 +1654,7 @@ class GemmOperationBase:
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_a=DataTypeNames[self.A.element],
element_b=DataTypeNames[self.B.element],
element_acc=DataTypeNames[self.tile_description.math_instruction.element_accumulator],
element_acc=DataTypeNames[self.accumulator_type()],
element_c=DataTypeNames[self.C.element],
element_d=DataTypeNames[self.epilogue_functor.element_output],
core_name=self.core_name())

View File

@ -118,16 +118,18 @@ cutlass::Status ${name}_kernel_run(
typename DeviceKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, L}, // problem size
A, // ptrA
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
B, // ptrB
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
{
A, // ptrA
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
B, // ptrB
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
},
{
{alpha, beta},
C, // ptrC
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
D, // ptrD
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
{alpha, beta},
},
hw_info
};

View File

@ -232,7 +232,7 @@ _PYTORCH_GEMM_INCLUDES = {
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
""",
}
@ -583,7 +583,11 @@ setup(
'${name}_kernel.cu',
],
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
extra_compile_args=['-std=c++17']
extra_compile_args={
'cxx': ['-std=c++17'],
'nvcc': ['-std=c++17', ${extra_compile_args}],
},
libraries=['cuda']
),
],
cmdclass={
@ -593,7 +597,7 @@ setup(
"""
def _generate_setup(name: str, sourcedir: str):
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
"""
Generates a setup.py file for the extension
@ -601,10 +605,12 @@ def _generate_setup(name: str, sourcedir: str):
:type name: str
:param sourcedir: directory to which generated source files should be written
:type sourcedir: str
:param extra_compile_args: additional arguments to pass to setup.py
:type extra_args: str
"""
setup_py_file = os.path.join(sourcedir, "setup.py")
setup_source = SubstituteTemplate(
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH}
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args}
)
with open(setup_py_file, "w") as outfile:
outfile.write(setup_source)
@ -696,6 +702,7 @@ def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
os.path.join(CUTLASS_PATH, "include"),
os.path.join(CUTLASS_PATH, "tools/util/include"),
],
extra_ldflags=["-lcuda"],
verbose=(logger.level == logging.DEBUG)
)
return jitmodule
@ -759,7 +766,10 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
with open(cpp_file, "w") as outfile:
outfile.write(cpp_source)
_generate_setup(name, sourcedir)
extra_compile_args = ""
if cc == 90:
extra_compile_args = "'--generate-code=arch=compute_90a,code=[sm_90a]'"
_generate_setup(name, sourcedir, extra_compile_args)
if jit:
return _jit(name, cc, cpp_file, cuda_file)

View File

@ -137,9 +137,9 @@ class KernelsForDataType:
# Finally, go through all available alignment combinations and find
# one for which all values are less than those passed in.
key = None
alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
for align_A, align_B, align_C in alignments:
if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C:
if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
key = f"{align_A} {align_B} {align_C}"
break

View File

@ -712,4 +712,4 @@ class Gemm(OperationBase):
if sync:
arguments.sync()
return arguments
return arguments

View File

@ -205,7 +205,7 @@ class GemmOperation:
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_a = DataTypeNames[self.A.element],
element_b = DataTypeNames[self.B.element],
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
element_acc = DataTypeNames[self.accumulator_type()],
element_c = DataTypeNames[self.C.element],
element_d = DataTypeNames[self.D.element],
core_name = self.core_name())
@ -216,7 +216,7 @@ class GemmOperation:
datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_a = DataTypeNames[self.A.element],
element_b = DataTypeNames[self.B.element],
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
element_acc = DataTypeNames[self.accumulator_type()],
element_c = DataTypeNames[self.C.element],
element_d = DataTypeNames[self.D.element])
return datatype_name
@ -744,7 +744,7 @@ using ${operation_name}_mainloop =
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
${stages},
${kernel_schedule}
${kernel_schedule}
>::CollectiveOp;
// Gemm operator ${operation_name}
@ -817,8 +817,9 @@ ${compile_guard_end}
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
#
element_a = DataTypeTag[operation.A.element]
element_b = DataTypeTag[operation.B.element]
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
values = {
'operation_name': operation.procedural_name(),

View File

@ -967,6 +967,7 @@ class ConvOperation3x:
def configuration_name(self):
prefix = 'cutlass3x'
arch = self.arch
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
tbm = self.tile_description.tile_shape[0]
tbn = self.tile_description.tile_shape[1]
@ -979,7 +980,7 @@ class ConvOperation3x:
kernel_schedule = KernelScheduleSuffixes[self.kernel_schedule]
epilogue_schedule = EpilogueScheduleSuffixes[self.epilogue_schedule]
return f"{prefix}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
return f"{prefix}_sm{arch}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
def procedural_name(self):
return self.configuration_name()

View File

@ -250,6 +250,12 @@ ComplexTransformTag = {
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}
# Used for cutlass3x complex kernel collective mainloop builder instantiation
ComplexTransformTag3x = {
ComplexTransform.none: 'cute::identity',
ComplexTransform.conj: 'cute::conjugate',
}
#
RealComplexBijection = [
(DataType.f16, DataType.cf16),