Add support for mixed 4-bit/8-bit data types GEMM (#1413)
* Add support for mixed 4-bit/8-bit data types GEMM

* fix ( and )

---------

Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
commit e1976daacc (parent f7b19de32c)
@@ -2855,6 +2855,167 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
      op.C.alignment = 8

#
def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):

  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
    return

  layouts = [
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  # Upcast on Operand A
  math_instructions = [
    MathInstruction( \
      [16, 8, 32], \
      DataType.s4, DataType.s8, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_mixed_input_upcast),
  ]

  min_cc = 80
  max_cc = 1024

  # For mixed input, the alignment constraints are a list of lists, where the
  # inner list contains the alignment constraints for the operands/matrices:
  # [[alignA, alignB, alignC], ...]
  alignment_constraints = [[32, 16, 4],]
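  # Aside (editorial assumption, not part of the original commit): these values
  # match 128-bit vectorized accesses, i.e. alignment = 128 / sizeof_bits(dtype):
  #   A (s4):  128 / 4  = 32
  #   B (s8):  128 / 8  = 16
  #   C (s32): 128 / 32 =  4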

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([ 32, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
    ]

    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      math_inst.element_accumulator,
      math_inst.element_accumulator,
    ]

    # Stream-K uses more registers, which can cause spills for the biggest
    # warp tile size when the accumulators are 32-bit.
    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)

    # Avoid emitting two kernels if the accumulator type does not differ from
    # the input type (e.g. S8 accumulation).
    if math_inst.element_a != math_inst.element_accumulator:
      alignment_constraints = [[32, 16, 16],]
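      # (Assumption, not stated in the commit): C alignment rises from 4 to 16
      # here because the clamped epilogue writes s8 output instead of s32, and
      # 128 / 8 = 16 elements fill a 128-bit store.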

      data_type_mixed = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_b,
        DataType.f32,
      ]

      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8)

    for op in operations:
      if op.tile_description.threadblock_shape[1] >= 128:
        if op.tile_description.threadblock_shape[0] == 32:
          op.C.alignment = 8
        else:
          op.C.alignment = 16
      else:
        op.C.alignment = 8
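    # (Assumption, not stated in the commit): the per-tile C alignment tuning
    # above appears to match the epilogue's store vector width to the
    # threadblock shape, with narrow-N or M == 32 tiles falling back to the
    # smaller alignment of 8.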

#
def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):

  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
    return

  layouts = [
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  # Upcast on Operand B
  math_instructions = [
    MathInstruction( \
      [16, 8, 32], \
      DataType.s8, DataType.s4, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_mixed_input_upcast),
  ]

  min_cc = 80
  max_cc = 1024

  # For mixed input, the alignment constraints are a list of lists, where the
  # inner list contains the alignment constraints for the operands/matrices:
  # [[alignA, alignB, alignC], ...]
  alignment_constraints = [[16, 32, 4],]
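  # As above (editorial note): 128-bit accesses give 128 / 8 = 16 for the s8 A
  # operand, 128 / 4 = 32 for the s4 B operand, and 128 / 32 = 4 for s32 C.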

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  32,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128,  32,  64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
    ]

    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      math_inst.element_accumulator,
      math_inst.element_accumulator,
    ]

    # Stream-K uses more registers, which can cause spills for the biggest
    # warp tile size when the accumulators are 32-bit.
    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)

    # Avoid emitting two kernels if the accumulator type does not differ from
    # the input type (e.g. S8 accumulation).
    if math_inst.element_a != math_inst.element_accumulator:
      alignment_constraints = [[16, 32, 16],]
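      # As in the upcast_a generator (editorial note): the s8 epilogue output
      # permits a 16-element C alignment (128 / 8 = 16).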

      data_type_mixed = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_a,
        DataType.f32,
      ]

      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8)

    for op in operations:
      if op.tile_description.threadblock_shape[1] >= 128:
        if op.tile_description.threadblock_shape[0] == 32:
          op.C.alignment = 8
        else:
          op.C.alignment = 16
      else:
        op.C.alignment = 8

#
def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version):

@@ -4699,6 +4860,8 @@ def GenerateSM80(manifest, cuda_version):
  GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
  GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
  GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
  GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
  GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
  GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)
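
A minimal sketch of driving the new generators standalone (hypothetical: the
Manifest construction and emit call follow generator.py's existing conventions,
but the exact option fields shown are assumptions, not part of this commit):

  # Build a Manifest, run only the new mixed s4/s8 SM80 generators, and emit
  # the library sources into a scratch directory.
  import argparse
  from library import GeneratorTarget
  from manifest import Manifest

  args = argparse.Namespace(
    operations='all', build_dir='/tmp/cutlass_gen', curr_build_dir='/tmp/cutlass_gen',
    generator_target='library', architectures='80', kernels='', ignore_kernels='',
    filter_by_cc='True', cuda_version='11.8', kernel_filter_file=None,
    selected_kernel_list=None, interface_dir=None, disable_full_archs_compilation=False)

  manifest = Manifest(args)
  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, args.cuda_version)
  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, args.cuda_version)
  manifest.emit(GeneratorTarget.Library)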