[fix] resolve cutlass_scaled_mm inference error by making input tensors contiguous before the CUTLASS op

This commit is contained in:
tangshiwen
2026-01-06 20:52:12 +08:00
parent c54b2d2a2d
commit f811ae968a

View File

@@ -1646,7 +1646,9 @@ def cutlass_scaled_mm(
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
)
return out
@@ -1660,7 +1662,9 @@ def cutlass_scaled_mm_cuda(
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
)
return out
@@ -1694,7 +1698,7 @@ def cutlass_scaled_mm_azp(
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
out, a, b, scale_a, scale_b, azp_adj, azp, bias
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
)
return out
@@ -1712,7 +1716,7 @@ def cutlass_scaled_mm_azp_cuda(
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
out, a, b, scale_a, scale_b, azp_adj, azp, bias
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
)
return out