[Kernel] Apply torch.Tag.needs_fixed_stride_order only for torch==2.6.0 (#19346)

Signed-off-by: rzou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2025-07-18 14:10:21 -04:00
committed by GitHub
parent 21274ab476
commit b2eb2b5ad7
3 changed files with 19 additions and 9 deletions

View File

@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
//
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
// The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
// so we need
// to override this for many GEMMs with the following tag. Otherwise,
// torch.compile will force all input tensors to be contiguous(), which
// will break many custom ops that require column-major weight matrices.
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
// to match exact eager-mode strides.
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
// This was a bug and PyTorch 2.7 has since fixed this.
#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
#define stride_tag at::Tag::needs_fixed_stride_order
#else
#define stride_tag
#endif
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);