[fix] resolve cutlass_scaled_mm inference error
This commit is contained in:
@@ -1646,7 +1646,9 @@ def cutlass_scaled_mm(
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
|
||||
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@@ -1660,7 +1662,9 @@ def cutlass_scaled_mm_cuda(
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
|
||||
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@@ -1694,7 +1698,7 @@ def cutlass_scaled_mm_azp(
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
||||
out, a, b, scale_a, scale_b, azp_adj, azp, bias
|
||||
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -1712,7 +1716,7 @@ def cutlass_scaled_mm_azp_cuda(
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
||||
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
||||
out, a, b, scale_a, scale_b, azp_adj, azp, bias
|
||||
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user