diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index a299ba0ff..ea9060972 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -177,7 +177,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { */ m.def( "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " - "cublas_handle, int cuda_stream) -> ()"); + "cublas_handle, int cuda_stream) -> ()", + {at::Tag::needs_fixed_stride_order}); m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); m.def(