Support for Mixed Input TensorOp (#1084)
* Passing warp-level mixed input F16*(S8/U8) tests * passing device-level mixed input F16*(S8/U8) tests * add to profiler - I8 (111 TFLOPs), U (123 TFLOPs) * fast numeric conversions (I8 = 132 TFLOPs, U8 = 148 TFLOPs) * Speedup reference compilation (REVERT THIS COMMIT) * wider_add.u32_packed_sub.f16x2 (I8 = 132TFLOP/s, U8 = 170 TFLOP/s) * Improve s8->f16 cvt and support bf16*u8 @158 TFLOPs * BF16 * S8 (142 TFLOPs) * Handle mixed-input upcast on OperandA (Support [S8|U8]*[F16|BF16] * rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast * Add device-level test and profiler support for upcast on operand A * Move shfl before the cvt and reduce #shfls by 1/2 * fix smem_usage calculation for mixed_input types * uncomment the stuff (getting ready for merge) * profiler changes and mixed-input reference * mixed input reference are in a new file * use platform instead of std * comments and typo only * Use CreateGemmOperator and delete CreateMixedInputGemmOperator * copyright for new files * rebase follow-up
This commit is contained in:
@ -60,7 +60,11 @@ class Conv3dOperation:
|
||||
self.iterator_algorithm = iterator_algorithm
|
||||
self.stride_support = stride_support
|
||||
self.swizzling_functor = swizzling_functor
|
||||
|
||||
|
||||
#
|
||||
def is_mixed_input(self):
|
||||
return self.A.element != self.B.element
|
||||
|
||||
#
|
||||
def core_name(self):
|
||||
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
||||
|
||||
Reference in New Issue
Block a user