Lazy scipy import (#2250)
This commit is contained in:
@ -34,7 +34,6 @@ import ctypes
|
||||
|
||||
from cutlass_library import SubstituteTemplate
|
||||
import numpy as np
|
||||
from scipy.special import erf
|
||||
|
||||
from cutlass_library import DataType, DataTypeTag
|
||||
from cutlass.backend.c_types import MatrixCoord_, tuple_factory
|
||||
@ -530,6 +529,7 @@ class hardswish(ActivationFunctor, metaclass=hardswishMeta):
|
||||
class geluMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
from scipy.special import erf
|
||||
return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user