Fix typos and unify size(s)/stride(s) API calls (#8799)
This commit is contained in:
@@ -353,7 +353,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
|
||||
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
|
||||
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
|
||||
|
||||
TORCH_CHECK(
|
||||
|
||||
Reference in New Issue
Block a user