CUTLASS 2.10 updates (#622)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2022-09-12 18:26:30 -07:00
committed by GitHub
parent beae168f90
commit e773429f7e
96 changed files with 8365 additions and 1667 deletions

View File

@ -92,25 +92,35 @@ Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```
### Batched & Array GEMM
Example 1: Batched GEMM
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 2: Array GEMM
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
```
***
## GEMM Grouped Examples
The GEMM Grouped examples use numpy to create input tensors and verify the results.
Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle2 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle8 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***
## Conv2d Example
@ -160,3 +170,61 @@ Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nh
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
## Epilogue
### Bias
To replace C with a bias vector, add `-bias` flag.
### Activation function
Example 1: ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
```
Example 2: leaky ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
```
Example 3: tanh (alpha=0 to avoid saturation)
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
```
Example 4: sigmoid
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid
```
Example 5: SiLU
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
```
Example 6: HardSwish
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
```
Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
### Epilogue Visitor Tree
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```

View File

@ -33,6 +33,7 @@ import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
import torch.nn.functional as F
import argparse
@ -127,6 +128,13 @@ parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stri
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
@ -138,6 +146,8 @@ except:
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
@ -152,7 +162,7 @@ math_inst = MathInstruction(
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
@ -172,7 +182,16 @@ C = TensorDescription(
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
if (args.activation_function == "identity"
or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
@ -181,7 +200,7 @@ conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
operation = Conv2dOperation(
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue, stride_support=stride_support,
A=A, B=B, C=C, stride_support=stride_support,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -191,10 +210,18 @@ if args.print_cuda:
operations = [operation,]
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
@ -219,9 +246,18 @@ tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
conv_kind, problem_size
)
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
if args.bias:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
).at(3)
else:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
if args.element_a != "int8":
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
@ -238,12 +274,12 @@ if args.element_c != "int8":
else:
tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)
tensor_D = torch.ones_like(tensor_C)
tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda")
arguments = Conv2dArguments(
operation=operation, problem_size=problem_size, A=tensor_A,
B=tensor_B, C=tensor_C, D=tensor_D,
output_op = LinearCombinationFunctorArguments(args.alpha, args.beta),
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
split_k_slices=problem_size.split_k_slices
)
@ -257,7 +293,8 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op = LinearCombinationFunctorArguments(args.alpha, args.beta)
output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
operation.run(arguments)
@ -270,8 +307,12 @@ else:
reference_model = Conv2dReferenceModule(A, B, C, conv_kind)
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta)
assert torch.equal(tensor_D, tensor_D_ref)
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias)
if (args.activation_function != "identity"):
tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args))
try:
assert torch.equal(tensor_D, tensor_D_ref)
except:
assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2)
print("Passed.")

View File

@ -99,9 +99,11 @@ parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
parser.add_argument("-epv", "--epilogue_visitor", default=None,
type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# Argument
@ -113,17 +115,22 @@ parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
parser.add_argument("-beta", "--beta", default=0.0, type=float,
help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
choices=["Gemm", "GemmSplitKParallel"],
choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"],
help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
GemmSplitKParallel is used for parallel splitK")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
# parser.add_argument('-h', '--help', action="store_true",
# help="print help information")
try:
args = parser.parse_args()
@ -131,6 +138,9 @@ except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
pycutlass.compiler.nvcc()
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
@ -146,7 +156,7 @@ math_inst = MathInstruction(
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
@ -166,13 +176,83 @@ C = TensorDescription(
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
if (args.activation_function == "identity"
or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
visitor = args.epilogue_visitor is not None
if args.epilogue_visitor == "ColumnReduction":
class ColumnReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + beta * c
reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
return D, reduction
epilogue_functor = ColumnReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowReduction":
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowBroadcast":
class RowBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'row', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = alpha * T
Z = relu.numpy(scale_T + beta * c)
return Z, T
epilogue_functor = RowBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "ColumnBroadcast":
class ColumnBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'column', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = leaky_relu.numpy(alpha * T, 0.2)
Z = scale_T + beta * c
return Z, T
epilogue_functor = ColumnBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
else:
epilogue_functor = epilogue_functor
operation = GemmOperationUniversal(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
visitor=visitor
)
if args.print_cuda:
@ -181,10 +261,19 @@ if args.print_cuda:
operations = [operation, ]
if args.gemm_mode == "GemmSplitKParallel":
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
@ -196,47 +285,102 @@ pycutlass.compiler.add_module(operations)
problem_size = cutlass.gemm.GemmCoord(
args.problem_size[0], args.problem_size[1], args.problem_size[2])
tensor_a_size = args.batch * problem_size.m() * problem_size.k()
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(bfloat16)
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(bfloat16)
else:
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(getattr(np, args.element_a))
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.k(),)).astype(getattr(np, args.element_a))
tensor_A = np.random.uniform(
low=-2, high=2,size=(tensor_a_size,)
).astype(getattr(np, args.element_a))
tensor_b_size = args.batch * problem_size.k() * problem_size.n()
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(bfloat16)
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(bfloat16)
else:
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(getattr(np, args.element_b))
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
* problem_size.n(),)).astype(getattr(np, args.element_b))
tensor_B = np.random.uniform(
low=-2, high=2, size=(tensor_b_size,)
).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.element_c == "bfloat16":
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(bfloat16)
if args.bias:
if args.layout_c == "RowMajor":
tensor_c_size = args.batch * problem_size.n()
elif args.layout_c == "ColumnMajor":
tensor_c_size = args.batch * problem_size.m()
else:
raise ValueError(args.layout_c)
else:
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(getattr(np, args.element_c))
tensor_c_size = args.batch * problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.n(),)).astype(getattr(np, args.element_c))
tensor_C = np.random.uniform(
low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.ones_like(tensor_C)
tensor_D = np.zeros(
shape=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
if args.epilogue_visitor == "RowReduction":
cta_n = args.threadblock_shape[1]
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnReduction":
cta_m = args.threadblock_shape[0]
num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "RowBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
else:
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
arguments = GemmArguments(
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta),
output_op=output_op,
gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
split_k_slices=args.split_k_slices
split_k_slices=args.split_k_slices, batch=args.batch
)
if args.gemm_mode == "GemmSplitKParallel":
@ -245,7 +389,8 @@ if args.gemm_mode == "GemmSplitKParallel":
problem_size=[problem_size.m(), problem_size.n()],
partitions=args.split_k_slices, workspace=arguments.ptr_D,
destination=tensor_D, source=tensor_C,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
operation.run(arguments)
@ -259,8 +404,42 @@ else:
# run the host reference module
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta)
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)
assert np.array_equal(tensor_D, tensor_D_ref)
if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, reduction_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
args.alpha, args.beta
)
tensor_D_ref = tensor_D_ref.flatten()
reduction_ref = reduction_ref.flatten()
assert np.allclose(reduction_ref, reduction, atol=1e-2)
elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, tensor_T_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
vector, args.alpha, args.beta)
tensor_D_ref = tensor_D_ref.flatten()
tensor_T_ref = tensor_T_ref.flatten()
assert np.array_equal(tensor_t, tensor_T_ref)
try:
assert np.array_equal(tensor_D, tensor_D_ref)
except:
assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5)
print("Passed.")

View File

@ -99,7 +99,10 @@ parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
help="This option describes how thread blocks are scheduled on GPU. \
NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \
This parameter is passed in at present to match the APIs of other kernels. The parameter \
is unused within the kernel")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
default="Device", type=str, choices=["Host", "Device"],
@ -109,7 +112,13 @@ parser.add_argument("-p", "--problem_size_dir", type=str,
help="path to the csv file contains the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
@ -120,6 +129,8 @@ except:
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
@ -134,7 +145,7 @@ math_inst = MathInstruction(
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
@ -154,13 +165,19 @@ C = TensorDescription(
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
if args.activation_function == "identity":
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
operation = GemmOperationGrouped(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
precompute_mode=precompute_mode
)
@ -214,28 +231,45 @@ for problem_size in problem_sizes:
* problem_size.n(),)).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.element_c == "bfloat16":
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(bfloat16)
if args.bias:
if args.layout_c == "RowMajor":
c_size = problem_size.n()
elif args.layout_c == "ColumnMajor":
c_size = problem_size.m()
else:
raise ValueError(args.layout_c)
else:
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(getattr(np, args.element_c))
c_size = problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.n(),)).astype(getattr(np, args.element_c))
tensor_D = np.zeros_like(tensor_C)
tensor_C = np.random.uniform(
low=-2, high=2, size=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.zeros(
shape=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_As.append(tensor_A)
tensor_Bs.append(tensor_B)
tensor_Cs.append(tensor_C)
tensor_Ds.append(tensor_D)
tensor_D_refs.append(reference_module.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta))
tensor_D_ref = reference_module.run(
tensor_A, tensor_B, tensor_C, problem_size,
args.alpha, args.beta, args.bias)
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
tensor_D_refs.append(tensor_D_ref)
problem_sizes_coord.append(problem_size)
arguments = GemmGroupedArguments(
operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
)
operation.run(arguments)
@ -243,6 +277,9 @@ operation.run(arguments)
arguments.sync()
for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
assert np.array_equal(tensor_d, tensor_d_ref)
try:
assert np.array_equal(tensor_d, tensor_d_ref)
except:
assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5)
print("Passed.")