[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (#7651)
This commit is contained in:
@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
|
||||
&cutlass_scaled_mm_supports_fp8);
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
"Tensor! A, Tensor! B, Tensor! C,"
|
||||
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
|
||||
"bool delta_softplus,"
|
||||
"Tensor? index_, Tensor? x) -> Tensor[]");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor? seq_idx_,"
|
||||
"Tensor? initial_states_,"
|
||||
"Tensor? final_states_out_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
|
||||
Reference in New Issue
Block a user