diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py
index 2753f71..e4010ae 100644
--- a/vllm_kunlun/vllm_utils_wrapper.py
+++ b/vllm_kunlun/vllm_utils_wrapper.py
@@ -1646,7 +1646,9 @@ def cutlass_scaled_mm(
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
-    torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+    torch.ops.xspeedgate_ops.cutlass_scaled_mm(
+        out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
+    )
     return out
 
 
@@ -1660,7 +1662,9 @@ def cutlass_scaled_mm_cuda(
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
-    torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+    torch.ops.xspeedgate_ops.cutlass_scaled_mm(
+        out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
+    )
     return out
 
 
@@ -1694,7 +1698,7 @@ def cutlass_scaled_mm_azp(
 ) -> torch.Tensor:
     out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
     torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
-        out, a, b, scale_a, scale_b, azp_adj, azp, bias
+        out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
     )
     return out
 
 
@@ -1712,7 +1716,7 @@ def cutlass_scaled_mm_azp_cuda(
 ) -> torch.Tensor:
     out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
     torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
-        out, a, b, scale_a, scale_b, azp_adj, azp, bias
+        out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
     )
     return out
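
Illustrative note (not part of the patch): wrapping the inputs in .contiguous() suggests the underlying xspeedgate_ops kernels assume densely packed, row-major buffers, so strided views (e.g. from transposing or slicing) must be materialized before the op is called. A minimal sketch, using only stock PyTorch and hypothetical tensor names, of what .contiguous() does:

    import torch

    a = torch.randn(4, 8)
    a_t = a.t()                     # transpose is a strided view, no data copy
    print(a_t.is_contiguous())      # False: strides are no longer row-major
    packed = a_t.contiguous()       # copies the view into a packed, row-major buffer
    print(packed.is_contiguous())   # True
    print(a.contiguous() is a)      # True: no-op when the tensor is already contiguous

Because .contiguous() returns the tensor itself when it is already packed, the extra calls in the patch add no copy on the common path and only pay for a copy when a non-contiguous view is passed in.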