[fix] resolve cutlass_scaled_mm inference error by passing contiguous `a` and `b` tensors to the cutlass_scaled_mm and cutlass_scaled_mm_azp ops
This commit is contained in:
@@ -1646,7 +1646,9 @@ def cutlass_scaled_mm(
|
|||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
|
||||||
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -1660,7 +1662,9 @@ def cutlass_scaled_mm_cuda(
|
|||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
|
||||||
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -1694,7 +1698,7 @@ def cutlass_scaled_mm_azp(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
||||||
out, a, b, scale_a, scale_b, azp_adj, azp, bias
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -1712,7 +1716,7 @@ def cutlass_scaled_mm_azp_cuda(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
||||||
out, a, b, scale_a, scale_b, azp_adj, azp, bias
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user