Add support for mixed 4-bit/8-bit data types GEMM (#1413)

* Add support for mixed 4-bit/8-bit data types GEMM

* fix ( and )

---------

Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Aleksandar Samardžić
2024-08-30 05:11:06 +02:00
committed by GitHub
parent f7b19de32c
commit e1976daacc
15 changed files with 960 additions and 14 deletions

View File

@ -2855,6 +2855,167 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
op.C.alignment = 8
#
def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
]
# Upcast on Operand A
math_instructions = [
MathInstruction( \
[16, 8, 32], \
DataType.s4, DataType.s8, DataType.s32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
]
min_cc = 80
max_cc = 1024
# For mixed-input alignment constraints are a list of lists, where the
# inner list contains the alignment constraints for operands/matrices
# [[alignA, alignB, alignC],..]
alignment_constraints = [[32, 16, 4],]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]
# streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
alignment_constraints = [[32, 16, 16],]
data_type_mixed = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_b,
DataType.f32
]
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
if op.tile_description.threadblock_shape[0] == 32:
op.C.alignment = 8
else:
op.C.alignment = 16
else:
op.C.alignment = 8
#
def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
]
# Upcast on Operand B
math_instructions = [
MathInstruction( \
[16, 8, 32], \
DataType.s8, DataType.s4, DataType.s32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
]
min_cc = 80
max_cc = 1024
# For mixed-input alignment constraints are a list of lists, where the
# inner list contains the alignment constraints for operands/matrices
# [[alignA, alignB, alignC],..]
alignment_constraints = [[16, 32, 4],]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]
# streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
alignment_constraints = [[16, 32, 16],]
data_type_mixed = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
DataType.f32,
]
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
if op.tile_description.threadblock_shape[0] == 32:
op.C.alignment = 8
else:
op.C.alignment = 16
else:
op.C.alignment = 8
#
def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version):
@ -4699,6 +4860,8 @@ def GenerateSM80(manifest, cuda_version):
GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)