Fix typos and unify size(s)/stride(s) API calls (#8799)
This commit is contained in:
@@ -640,9 +640,9 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch:
|
||||
TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]")
|
||||
TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]")
|
||||
|
||||
TORCH_CHECK(mat_a.strides()[1] == 1); // Row-major
|
||||
TORCH_CHECK(output.strides()[1] == 1); // Row-major
|
||||
TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); // Row-major
|
||||
TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); // Row-major
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); // Column-major
|
||||
|
||||
auto const data_type = mat_a.scalar_type();
|
||||
TORCH_CHECK(
|
||||
|
||||
Reference in New Issue
Block a user