[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (#7651)

This commit is contained in:
Mor Zusman
2024-08-29 01:06:52 +03:00
committed by GitHub
parent 8c56e57def
commit fdd9daafa3
20 changed files with 2815 additions and 31 deletions

View File

@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
&cutlass_scaled_mm_supports_fp8);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor? x) -> Tensor[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor? final_states_out_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
// Quantized GEMM for GPTQ.