Fix typos and unify size(s)/stride(s) API calls (#8799)

This commit is contained in:
triple-mu
2025-08-08 15:18:08 +08:00
committed by GitHub
parent 9c7e392465
commit 444013585d
6 changed files with 34 additions and 34 deletions

View File

@@ -640,9 +640,9 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch:
TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]")
TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]")
TORCH_CHECK(mat_a.strides()[1] == 1); // Row-major
TORCH_CHECK(output.strides()[1] == 1); // Row-major
TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); // Row-major
TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); // Row-major
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); // Column-major
auto const data_type = mat_a.scalar_type();
TORCH_CHECK(