From 444013585d6d3d3583aea6d105e9bfb4e2ccf933 Mon Sep 17 00:00:00 2001 From: triple-mu Date: Fri, 8 Aug 2025 15:18:08 +0800 Subject: [PATCH] Fix typos and unify size(s)/stride(s) API calls (#8799) --- .../csrc/attention/cutlass_mla_kernel.cu | 10 ++-- sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu | 6 +-- .../csrc/gemm/fp8_blockwise_gemm_kernel.cu | 2 +- sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu | 2 +- sgl-kernel/csrc/gemm/int8_gemm_kernel.cu | 2 +- .../csrc/gemm/nvfp4_scaled_mm_kernels.cu | 46 +++++++++---------- 6 files changed, 34 insertions(+), 34 deletions(-) diff --git a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu index 7c060274b..88c4c89e2 100644 --- a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu +++ b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu @@ -105,10 +105,10 @@ typename T::Fmha::Arguments args_from_options( hw_info.device_id = q_nope.device().index(); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - int batches = q_nope.sizes()[0]; - int page_count_per_seq = page_table.sizes()[1]; - int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; - int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int batches = q_nope.size(0); + int page_count_per_seq = page_table.size(1); + int page_count_total = kv_c_and_k_pe_cache.size(0); + int page_size = kv_c_and_k_pe_cache.size(1); int max_seq_len = page_size * page_count_per_seq; using TileShapeH = typename T::TileShapeH; using TileShapeD = typename T::TileShapeD; @@ -220,7 +220,7 @@ void cutlass_mla_decode( auto in_dtype = q_nope.dtype(); at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); - const int page_size = kv_c_and_k_pe_cache.sizes()[1]; + const int page_size = kv_c_and_k_pe_cache.size(1); // NOTE(alcanderian): IsPersistent has bug with manual split_kv. // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8) diff --git a/sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu index 39548c537..28dcaaee1 100644 --- a/sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu @@ -640,9 +640,9 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch: TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]") TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]") - TORCH_CHECK(mat_a.strides()[1] == 1); // Row-major - TORCH_CHECK(output.strides()[1] == 1); // Row-major - TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); // Row-major + TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); // Row-major + TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); // Column-major auto const data_type = mat_a.scalar_type(); TORCH_CHECK( diff --git a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu index 33e536f16..1c082da4e 100755 --- a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu +++ b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu @@ -353,7 +353,7 @@ torch::Tensor fp8_blockwise_scaled_mm( TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); - TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); TORCH_CHECK( diff --git a/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu index d3bc610f3..0d25e9985 100644 --- a/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu +++ b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu @@ -1080,7 +1080,7 @@ torch::Tensor fp8_scaled_mm( TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); - TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); TORCH_CHECK( diff --git a/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu index a81dba3d9..f18c81865 100644 --- a/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu +++ b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu @@ -672,7 +672,7 @@ torch::Tensor int8_scaled_mm( TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); - TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment"); TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment"); diff --git a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu index d1193ea44..cc4804298 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu @@ -273,20 +273,20 @@ void cutlass_scaled_fp4_mm_sm100a( TORCH_CHECK(A.dim() == 2, "a must be a matrix"); TORCH_CHECK(B.dim() == 2, "b must be a matrix"); TORCH_CHECK( - A.sizes()[1] == B.sizes()[1], + A.size(1) == B.size(1), "a and b shapes cannot be multiplied (", - A.sizes()[0], + A.size(0), "x", - A.sizes()[1], + A.size(1), " and ", - B.sizes()[0], + B.size(0), "x", - B.sizes()[1], + B.size(1), ")"); - auto const m = A.sizes()[0]; - auto const n = B.sizes()[0]; - auto const k = A.sizes()[1] * 2; + auto const m = A.size(0); + auto const n = B.size(0); + auto const k = A.size(1) * 2; constexpr int alignment = 32; TORCH_CHECK( @@ -294,9 +294,9 @@ void cutlass_scaled_fp4_mm_sm100a( "Expected k to be divisible by ", alignment, ", but got a shape: (", - A.sizes()[0], + A.size(0), "x", - A.sizes()[1], + A.size(1), "), k: ", k, "."); @@ -305,9 +305,9 @@ void cutlass_scaled_fp4_mm_sm100a( "Expected n to be divisible by ", alignment, ", but got b shape: (", - B.sizes()[0], + B.size(0), "x", - B.sizes()[1], + B.size(1), ")."); auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; @@ -320,37 +320,37 @@ void cutlass_scaled_fp4_mm_sm100a( TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); TORCH_CHECK( - A_sf.sizes()[1] == B_sf.sizes()[1], + A_sf.size(1) == B_sf.size(1), "scale_a and scale_b shapes cannot be multiplied (", - A_sf.sizes()[0], + A_sf.size(0), "x", - A_sf.sizes()[1], + A_sf.size(1), " and ", - B_sf.sizes()[0], + B_sf.size(0), "x", - B_sf.sizes()[1], + B_sf.size(1), ")"); TORCH_CHECK( - A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, + A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k, "scale_a must be padded and swizzled to a shape (", rounded_m, "x", rounded_k, "), but got a shape (", - A_sf.sizes()[0], + A_sf.size(0), "x", - A_sf.sizes()[1], + A_sf.size(1), ")"); TORCH_CHECK( - B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, + B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k, "scale_b must be padded and swizzled to a shape (", rounded_n, "x", rounded_k, "), but got a shape (", - B_sf.sizes()[0], + B_sf.size(0), "x", - B_sf.sizes()[1], + B_sf.size(1), ")"); auto out_dtype = D.dtype();