[Kernel] Fullgraph and opcheck tests (#8479)
This commit is contained in:
@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor! A, Tensor! B, Tensor! C,"
|
||||
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
|
||||
"bool delta_softplus,"
|
||||
"Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
|
||||
"Tensor? index_, Tensor!? x) -> Tensor[]");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor? bias_,"
|
||||
"Tensor? seq_idx_,"
|
||||
"Tensor? initial_states_,"
|
||||
"Tensor? final_states_out_,"
|
||||
"Tensor!? final_states_out_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user