[Kernel] Add cuda kernel for gpt_oss activation (#22951)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-08-17 13:03:24 +08:00
committed by GitHub
parent 87f48623a5
commit 4d4061b6e7
9 changed files with 157 additions and 42 deletions

View File

@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
}
}
template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
float alpha, float limit) {
// clamp gate: min=None, max=limit
const float gate_f = (float)gate;
const float clamped_gate = gate_f > limit ? limit : gate_f;
// clamp up: min=-limit, max=limit
const float up_f = (float)up;
const float clamped_up =
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
// glu = gate * sigmoid(gate * alpha)
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
const float glu = clamped_gate * sigmoid_val;
// (up + 1) * glu
return (T)((clamped_up + 1.0f) * glu);
}
template <typename scalar_t,
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
const float)>
__global__ void swigluoai_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d, const float alpha, const float limit) {
const int64_t token_idx = blockIdx.x;
// TODO: Vectorize loads and stores.
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
}
}
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
PARAM); \
});
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, ALPHA, \
LIMIT); \
});
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
void swigluoai_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
double alpha, double limit) {
LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
}
namespace vllm {
// Element-wise activation kernel template.

View File

@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
double threshold);
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
double alpha = 1.702, double limit = 7.0);
void gelu_new(torch::Tensor& out, torch::Tensor& input);

View File

@ -130,6 +130,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
ops.def(
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
"limit=7.0) "
"-> ()");
ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);