CUTLASS 2.9 (#468)
This commit is contained in:
@ -9,10 +9,10 @@ import os.path
|
||||
import shutil
|
||||
import functools
|
||||
import operator
|
||||
import collections
|
||||
|
||||
from library import *
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Data structure modeling a GEMM operation
|
||||
@ -159,7 +159,9 @@ class GemmOperation:
|
||||
class EmitGemmInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, operation_suffix = ''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = []
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
|
||||
@ -214,6 +216,15 @@ class EmitGemmInstance:
|
||||
>;
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
@ -264,7 +275,9 @@ class EmitGemmInstance:
|
||||
class EmitSparseGemmInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, operation_suffix = ''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = []
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
|
||||
@ -293,6 +306,15 @@ class EmitSparseGemmInstance:
|
||||
>;
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
@ -345,7 +367,26 @@ class EmitSparseGemmInstance:
|
||||
class EmitGemmUniversalInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, operation_suffix = ''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = [
|
||||
"cutlass/cutlass.h",
|
||||
"cutlass/numeric_types.h",
|
||||
"cutlass/arch/arch.h",
|
||||
"cutlass/arch/mma.h",
|
||||
"cutlass/layout/matrix.h",
|
||||
"cutlass/gemm/device/gemm.h",
|
||||
"cutlass/gemm/device/gemm_universal_adapter.h",
|
||||
"cutlass/gemm/kernel/default_gemm_universal.h",
|
||||
]
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
@ -359,19 +400,14 @@ using ${operation_name}_base =
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operation}
|
||||
>::GemmKernel;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name} :
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
"""
|
||||
self.gemm_template_interleaved = """
|
||||
@ -387,22 +423,28 @@ using ${operation_name}_base =
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operation}
|
||||
>::GemmKernel;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name} :
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
threadblock_shape = operation.tile_description.threadblock_shape
|
||||
@ -410,8 +452,6 @@ struct ${operation_name} :
|
||||
|
||||
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
||||
|
||||
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
|
||||
transpose_layouts = {
|
||||
LayoutType.ColumnMajor: LayoutType.RowMajor,
|
||||
LayoutType.RowMajor: LayoutType.ColumnMajor
|
||||
@ -433,8 +473,25 @@ struct ${operation_name} :
|
||||
gemm_template = self.gemm_template_interleaved
|
||||
#
|
||||
|
||||
# Support built-in epilogue functors or user-defined functions
|
||||
if isinstance(operation.epilogue_functor, enum.Enum):
|
||||
|
||||
epilogue_vector_length = \
|
||||
min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
|
||||
|
||||
values = {
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
#
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
'layout_a': LayoutTag[instance_layout_A],
|
||||
'element_b': DataTypeTag[operation.B.element],
|
||||
@ -453,9 +510,7 @@ struct ${operation_name} :
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'epilogue_functor': epilogue_functor,
|
||||
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'align_a': str(operation.A.alignment),
|
||||
@ -473,7 +528,9 @@ struct ${operation_name} :
|
||||
class EmitGemmPlanarComplexInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, operation_suffix = ''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = []
|
||||
self.template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
|
||||
@ -501,6 +558,17 @@ class EmitGemmPlanarComplexInstance:
|
||||
public Operation_${operation_name} { };
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
@ -547,7 +615,9 @@ class EmitGemmPlanarComplexInstance:
|
||||
class EmitGemmPlanarComplexArrayInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, operation_suffix = ''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = []
|
||||
self.template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
|
||||
@ -574,6 +644,17 @@ class EmitGemmPlanarComplexArrayInstance:
|
||||
struct ${operation_name} : public Operation_${operation_name} { };
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
@ -616,6 +697,130 @@ class EmitGemmPlanarComplexArrayInstance:
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class EmitGemmGroupedInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self, operation_suffix = ''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = [
|
||||
"cutlass/cutlass.h",
|
||||
"cutlass/numeric_types.h",
|
||||
"cutlass/arch/arch.h",
|
||||
"cutlass/arch/mma.h",
|
||||
"cutlass/layout/matrix.h",
|
||||
"cutlass/gemm/device/gemm.h",
|
||||
"cutlass/gemm/kernel/gemm_grouped.h",
|
||||
"cutlass/gemm/kernel/default_gemm_grouped.h",
|
||||
"cutlass/gemm/device/gemm_grouped.h"
|
||||
]
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
||||
${element_c}, ${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operation}
|
||||
>::GemmKernel;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmGrouped<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
threadblock_shape = operation.tile_description.threadblock_shape
|
||||
warp_count = operation.tile_description.warp_count
|
||||
|
||||
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
||||
|
||||
transpose_layouts = {
|
||||
LayoutType.ColumnMajor: LayoutType.RowMajor,
|
||||
LayoutType.RowMajor: LayoutType.ColumnMajor
|
||||
}
|
||||
|
||||
instance_layout_A, instance_layout_B, instance_layout_C = \
|
||||
(operation.A.layout, operation.B.layout, operation.C.layout)
|
||||
#
|
||||
|
||||
# Support built-in epilogue functors or user-defined functions
|
||||
if isinstance(operation.epilogue_functor, enum.Enum):
|
||||
|
||||
epilogue_vector_length = \
|
||||
min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
|
||||
|
||||
values = {
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
#
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
'layout_a': LayoutTag[instance_layout_A],
|
||||
'element_b': DataTypeTag[operation.B.element],
|
||||
'layout_b': LayoutTag[instance_layout_B],
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[instance_layout_C],
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
||||
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
||||
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
||||
'warp_shape_m': str(warp_shape[0]),
|
||||
'warp_shape_n': str(warp_shape[1]),
|
||||
'warp_shape_k': str(warp_shape[2]),
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_functor': epilogue_functor,
|
||||
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'align_a': str(operation.A.alignment),
|
||||
'align_b': str(operation.B.alignment),
|
||||
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
||||
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
||||
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.gemm_template, values)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
@ -633,7 +838,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Sparse: EmitSparseGemmInstance,
|
||||
GemmKind.Universal: EmitGemmUniversalInstance,
|
||||
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance
|
||||
}
|
||||
|
||||
self.gemm_kind_wrappers = {
|
||||
@ -641,61 +847,21 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.Sparse: 'GemmSparseOperation',
|
||||
GemmKind.Universal: 'GemmUniversalOperation',
|
||||
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
|
||||
GemmKind.Grouped: 'GemmGroupedOperation'
|
||||
}
|
||||
|
||||
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
|
||||
|
||||
self.instance_template = {
|
||||
GemmKind.Gemm: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
""",
|
||||
GemmKind.Sparse: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
""",
|
||||
GemmKind.Universal: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
""",
|
||||
GemmKind.PlanarComplex: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
""",
|
||||
GemmKind.PlanarComplexArray: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
self.separator = """
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
"""
|
||||
}
|
||||
|
||||
self.header_template = """
|
||||
/*
|
||||
Generated by gemm_operation.py - Do not edit.
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/arch/wmma.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "library_internal.h"
|
||||
#include "gemm_operation.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
"""
|
||||
|
||||
self.initialize_function_template = """
|
||||
@ -726,7 +892,16 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
def __enter__(self):
|
||||
self.configuration_file = open(self.configuration_path, "w")
|
||||
self.configuration_file.write(self.header_template)
|
||||
self.configuration_file.write(self.separator)
|
||||
|
||||
self.includes = collections.OrderedDict([
|
||||
("cutlass/cutlass.h", None),
|
||||
("cutlass/library/library.h", None),
|
||||
("cutlass/library/manifest.h", None),
|
||||
("library_internal.h", None),
|
||||
("gemm_operation.h", None),
|
||||
("cutlass/arch/wmma.h", None),
|
||||
])
|
||||
self.instance_definitions = []
|
||||
self.instance_wrappers = []
|
||||
|
||||
@ -736,11 +911,14 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
def emit(self, operation):
|
||||
emitter = self.instance_emitter[operation.gemm_kind]()
|
||||
|
||||
for incl in emitter.includes:
|
||||
self.includes[incl] = None
|
||||
|
||||
self.operations.append(operation)
|
||||
|
||||
self.instance_definitions.append(emitter.emit(operation))
|
||||
|
||||
self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
|
||||
self.instance_wrappers.append(SubstituteTemplate(emitter.instance_template(), {
|
||||
'configuration_name': self.configuration_name,
|
||||
'operation_name': operation.procedural_name(),
|
||||
'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
|
||||
@ -752,6 +930,13 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
|
||||
# Write includes
|
||||
for incl, _ in self.includes.items():
|
||||
include_statement = "#include \"%s\"\n" % incl
|
||||
self.configuration_file.write(include_statement)
|
||||
|
||||
self.configuration_file.write(self.separator)
|
||||
|
||||
# Write instance definitions in top-level namespace
|
||||
for instance_definition in self.instance_definitions:
|
||||
self.configuration_file.write(instance_definition)
|
||||
|
||||
Reference in New Issue
Block a user