Fix typos and unify size(s)/stride(s) API calls (#8799)
This commit is contained in:
@@ -640,9 +640,9 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch:
|
||||
TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]")
|
||||
TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]")
|
||||
|
||||
TORCH_CHECK(mat_a.strides()[1] == 1); // Row-major
|
||||
TORCH_CHECK(output.strides()[1] == 1); // Row-major
|
||||
TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); // Row-major
|
||||
TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); // Row-major
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); // Column-major
|
||||
|
||||
auto const data_type = mat_a.scalar_type();
|
||||
TORCH_CHECK(
|
||||
|
||||
@@ -353,7 +353,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
|
||||
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
|
||||
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
|
||||
|
||||
TORCH_CHECK(
|
||||
|
||||
@@ -1080,7 +1080,7 @@ torch::Tensor fp8_scaled_mm(
|
||||
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
|
||||
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
|
||||
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
|
||||
|
||||
TORCH_CHECK(
|
||||
|
||||
@@ -672,7 +672,7 @@ torch::Tensor int8_scaled_mm(
|
||||
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
|
||||
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
|
||||
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
|
||||
TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
|
||||
TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
|
||||
|
||||
@@ -273,20 +273,20 @@ void cutlass_scaled_fp4_mm_sm100a(
|
||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
TORCH_CHECK(
|
||||
A.sizes()[1] == B.sizes()[1],
|
||||
A.size(1) == B.size(1),
|
||||
"a and b shapes cannot be multiplied (",
|
||||
A.sizes()[0],
|
||||
A.size(0),
|
||||
"x",
|
||||
A.sizes()[1],
|
||||
A.size(1),
|
||||
" and ",
|
||||
B.sizes()[0],
|
||||
B.size(0),
|
||||
"x",
|
||||
B.sizes()[1],
|
||||
B.size(1),
|
||||
")");
|
||||
|
||||
auto const m = A.sizes()[0];
|
||||
auto const n = B.sizes()[0];
|
||||
auto const k = A.sizes()[1] * 2;
|
||||
auto const m = A.size(0);
|
||||
auto const n = B.size(0);
|
||||
auto const k = A.size(1) * 2;
|
||||
|
||||
constexpr int alignment = 32;
|
||||
TORCH_CHECK(
|
||||
@@ -294,9 +294,9 @@ void cutlass_scaled_fp4_mm_sm100a(
|
||||
"Expected k to be divisible by ",
|
||||
alignment,
|
||||
", but got a shape: (",
|
||||
A.sizes()[0],
|
||||
A.size(0),
|
||||
"x",
|
||||
A.sizes()[1],
|
||||
A.size(1),
|
||||
"), k: ",
|
||||
k,
|
||||
".");
|
||||
@@ -305,9 +305,9 @@ void cutlass_scaled_fp4_mm_sm100a(
|
||||
"Expected n to be divisible by ",
|
||||
alignment,
|
||||
", but got b shape: (",
|
||||
B.sizes()[0],
|
||||
B.size(0),
|
||||
"x",
|
||||
B.sizes()[1],
|
||||
B.size(1),
|
||||
").");
|
||||
|
||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||
@@ -320,37 +320,37 @@ void cutlass_scaled_fp4_mm_sm100a(
|
||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
TORCH_CHECK(
|
||||
A_sf.sizes()[1] == B_sf.sizes()[1],
|
||||
A_sf.size(1) == B_sf.size(1),
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.sizes()[0],
|
||||
A_sf.size(0),
|
||||
"x",
|
||||
A_sf.sizes()[1],
|
||||
A_sf.size(1),
|
||||
" and ",
|
||||
B_sf.sizes()[0],
|
||||
B_sf.size(0),
|
||||
"x",
|
||||
B_sf.sizes()[1],
|
||||
B_sf.size(1),
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
||||
A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (",
|
||||
rounded_m,
|
||||
"x",
|
||||
rounded_k,
|
||||
"), but got a shape (",
|
||||
A_sf.sizes()[0],
|
||||
A_sf.size(0),
|
||||
"x",
|
||||
A_sf.sizes()[1],
|
||||
A_sf.size(1),
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
||||
B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (",
|
||||
rounded_n,
|
||||
"x",
|
||||
rounded_k,
|
||||
"), but got a shape (",
|
||||
B_sf.sizes()[0],
|
||||
B_sf.size(0),
|
||||
"x",
|
||||
B_sf.sizes()[1],
|
||||
B_sf.size(1),
|
||||
")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
|
||||
Reference in New Issue
Block a user