Merge pull request #239 from KeDengMS/kedeng/gelu

Fixes to Gelu for half and fusion
This commit is contained in:
Haicheng Wu
2021-05-08 12:51:42 -04:00
committed by GitHub
3 changed files with 51 additions and 2 deletions

View File

@ -139,7 +139,7 @@ struct GELU {
/// Applies the GELU expression to a single value of element type T:
///   half<T>() * scalar * (one<T>() + erff(scalar / root_two<T>()))
/// with the constants supplied by cutlass::constants.
///
/// NOTE(review): this listing is a diff view. The two consecutive lines that
/// begin with `(cutlass::constants::one<T>() + ...` appear to be the
/// pre-change and post-change forms of the same continuation line; only one
/// of them can be present in the real header - confirm against the repository.
/// The second form converts the erff argument to float and converts the
/// erff result back to T explicitly.
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
return T(cutlass::constants::half<T>() * scalar *
(cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
}
};
@ -152,6 +152,15 @@ struct GELU<float> {
}
};
/// GELU specialization for double operands.
///
/// Evaluates half * x * (one + erf(x / root_two)), where the constants come
/// from cutlass::constants instantiated for double and erf is called on the
/// scaled input.
template <>
struct GELU<double> {
  /// @param scalar  input value x
  /// @return        half * x * (one + erf(x / root_two))
  CUTLASS_HOST_DEVICE
  double operator()(double const &scalar) const {
    // Argument handed to erf: the input divided by root_two.
    double const erf_input = scalar / cutlass::constants::root_two<double>();

    // Bracketed factor of the expression: one + erf(...).
    double const bracket = cutlass::constants::one<double>() + erf(erf_input);

    // Multiply left to right: (half * scalar) first, then the bracketed factor.
    return cutlass::constants::half<double>() * scalar * bracket;
  }
};
template <typename T, int N>
struct GELU<Array<T, N> > {
CUTLASS_HOST_DEVICE

View File

@ -133,7 +133,8 @@ public:
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition) {
void set_k_partition(int k_partition, int k_partition_count) {
CUTLASS_UNUSED(k_partition_count);
if (k_partition) {
beta_ = ElementCompute(1);
}