Updates for CUTLASS 3.5.0 (#1468)
This commit is contained in:
@ -1654,7 +1654,7 @@ class GemmOperationBase:
|
||||
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_a=DataTypeNames[self.A.element],
|
||||
element_b=DataTypeNames[self.B.element],
|
||||
element_acc=DataTypeNames[self.tile_description.math_instruction.element_accumulator],
|
||||
element_acc=DataTypeNames[self.accumulator_type()],
|
||||
element_c=DataTypeNames[self.C.element],
|
||||
element_d=DataTypeNames[self.epilogue_functor.element_output],
|
||||
core_name=self.core_name())
|
||||
|
||||
@ -118,16 +118,18 @@ cutlass::Status ${name}_kernel_run(
|
||||
typename DeviceKernel::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K, L}, // problem size
|
||||
A, // ptrA
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
||||
B, // ptrB
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
||||
{
|
||||
A, // ptrA
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
||||
B, // ptrB
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
||||
},
|
||||
{
|
||||
{alpha, beta},
|
||||
C, // ptrC
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
|
||||
D, // ptrD
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
|
||||
{alpha, beta},
|
||||
},
|
||||
hw_info
|
||||
};
|
||||
|
||||
@ -232,7 +232,7 @@ _PYTORCH_GEMM_INCLUDES = {
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
""",
|
||||
}
|
||||
@ -583,7 +583,11 @@ setup(
|
||||
'${name}_kernel.cu',
|
||||
],
|
||||
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
|
||||
extra_compile_args=['-std=c++17']
|
||||
extra_compile_args={
|
||||
'cxx': ['-std=c++17'],
|
||||
'nvcc': ['-std=c++17', ${extra_compile_args}],
|
||||
},
|
||||
libraries=['cuda']
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
@ -593,7 +597,7 @@ setup(
|
||||
"""
|
||||
|
||||
|
||||
def _generate_setup(name: str, sourcedir: str):
|
||||
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
|
||||
"""
|
||||
Generates a setup.py file for the extension
|
||||
|
||||
@ -601,10 +605,12 @@ def _generate_setup(name: str, sourcedir: str):
|
||||
:type name: str
|
||||
:param sourcedir: directory to which generated source files should be written
|
||||
:type sourcedir: str
|
||||
:param extra_compile_args: additional arguments to pass to setup.py
|
||||
:type extra_args: str
|
||||
"""
|
||||
setup_py_file = os.path.join(sourcedir, "setup.py")
|
||||
setup_source = SubstituteTemplate(
|
||||
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH}
|
||||
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args}
|
||||
)
|
||||
with open(setup_py_file, "w") as outfile:
|
||||
outfile.write(setup_source)
|
||||
@ -696,6 +702,7 @@ def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
|
||||
os.path.join(CUTLASS_PATH, "include"),
|
||||
os.path.join(CUTLASS_PATH, "tools/util/include"),
|
||||
],
|
||||
extra_ldflags=["-lcuda"],
|
||||
verbose=(logger.level == logging.DEBUG)
|
||||
)
|
||||
return jitmodule
|
||||
@ -759,7 +766,10 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
|
||||
with open(cpp_file, "w") as outfile:
|
||||
outfile.write(cpp_source)
|
||||
|
||||
_generate_setup(name, sourcedir)
|
||||
extra_compile_args = ""
|
||||
if cc == 90:
|
||||
extra_compile_args = "'--generate-code=arch=compute_90a,code=[sm_90a]'"
|
||||
_generate_setup(name, sourcedir, extra_compile_args)
|
||||
|
||||
if jit:
|
||||
return _jit(name, cc, cpp_file, cuda_file)
|
||||
|
||||
@ -137,9 +137,9 @@ class KernelsForDataType:
|
||||
# Finally, go through all available alignment combinations and find
|
||||
# one for which all values are less than those passed in.
|
||||
key = None
|
||||
alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
|
||||
alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
|
||||
for align_A, align_B, align_C in alignments:
|
||||
if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C:
|
||||
if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
|
||||
key = f"{align_A} {align_B} {align_C}"
|
||||
break
|
||||
|
||||
|
||||
@ -712,4 +712,4 @@ class Gemm(OperationBase):
|
||||
if sync:
|
||||
arguments.sync()
|
||||
|
||||
return arguments
|
||||
return arguments
|
||||
|
||||
@ -205,7 +205,7 @@ class GemmOperation:
|
||||
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
element_b = DataTypeNames[self.B.element],
|
||||
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
|
||||
element_acc = DataTypeNames[self.accumulator_type()],
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = DataTypeNames[self.D.element],
|
||||
core_name = self.core_name())
|
||||
@ -216,7 +216,7 @@ class GemmOperation:
|
||||
datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
element_b = DataTypeNames[self.B.element],
|
||||
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
|
||||
element_acc = DataTypeNames[self.accumulator_type()],
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = DataTypeNames[self.D.element])
|
||||
return datatype_name
|
||||
@ -744,7 +744,7 @@ using ${operation_name}_mainloop =
|
||||
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
|
||||
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
||||
${stages},
|
||||
${kernel_schedule}
|
||||
${kernel_schedule}
|
||||
>::CollectiveOp;
|
||||
|
||||
// Gemm operator ${operation_name}
|
||||
@ -817,8 +817,9 @@ ${compile_guard_end}
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
#
|
||||
element_a = DataTypeTag[operation.A.element]
|
||||
element_b = DataTypeTag[operation.B.element]
|
||||
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
|
||||
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
||||
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
|
||||
@ -967,6 +967,7 @@ class ConvOperation3x:
|
||||
|
||||
def configuration_name(self):
|
||||
prefix = 'cutlass3x'
|
||||
arch = self.arch
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
tbm = self.tile_description.tile_shape[0]
|
||||
tbn = self.tile_description.tile_shape[1]
|
||||
@ -979,7 +980,7 @@ class ConvOperation3x:
|
||||
kernel_schedule = KernelScheduleSuffixes[self.kernel_schedule]
|
||||
epilogue_schedule = EpilogueScheduleSuffixes[self.epilogue_schedule]
|
||||
|
||||
return f"{prefix}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
|
||||
return f"{prefix}_sm{arch}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
|
||||
|
||||
def procedural_name(self):
|
||||
return self.configuration_name()
|
||||
|
||||
@ -250,6 +250,12 @@ ComplexTransformTag = {
|
||||
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
|
||||
}
|
||||
|
||||
# Used for cutlass3x complex kernel collective mainloop builder instantiation
|
||||
ComplexTransformTag3x = {
|
||||
ComplexTransform.none: 'cute::identity',
|
||||
ComplexTransform.conj: 'cute::conjugate',
|
||||
}
|
||||
|
||||
#
|
||||
RealComplexBijection = [
|
||||
(DataType.f16, DataType.cf16),
|
||||
|
||||
Reference in New Issue
Block a user