diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index b117e27ae..220896975 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -164,7 +164,6 @@ set(SOURCES
     "csrc/elementwise/rope.cu"
     "csrc/gemm/awq_kernel.cu"
     "csrc/gemm/bmm_fp8.cu"
-    "csrc/gemm/cublas_grouped_gemm.cu"
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
     "csrc/gemm/fp8_gemm_kernel.cu"
     "csrc/gemm/int8_gemm_kernel.cu"
diff --git a/sgl-kernel/benchmark/bench_cublas_grouped_gemm.py b/sgl-kernel/benchmark/bench_cublas_grouped_gemm.py
deleted file mode 100644
index 15f4d7b7d..000000000
--- a/sgl-kernel/benchmark/bench_cublas_grouped_gemm.py
+++ /dev/null
@@ -1,262 +0,0 @@
-import argparse
-
-import torch
-import triton
-import triton.language as tl
-from sgl_kernel import cublas_grouped_gemm
-
-WEIGHT_CONFIGS = {
-    "DeepSeek-V2-Lite": {
-        "num_routed_experts": 64,
-        "ffn_shapes": [
-            [2048, 2816],
-            [1408, 2048],
-        ],
-    },
-    "DeepSeek-V2": {
-        "num_routed_experts": 160,
-        "ffn_shapes": [
-            [5120, 3072],
-            [1536, 5120],
-        ],
-    },
-}
-
-
-# This Triton Grouped Gemm Kernel is adapted from
-# https://triton-lang.org/main/getting-started/tutorials/08-grouped-gemm.html
-@triton.jit
-def grouped_matmul_kernel(
-    # device tensor of matrices pointers
-    group_a_ptrs,
-    group_b_ptrs,
-    group_c_ptrs,
-    # device tensor of gemm sizes. its shape is [group_size, 3]
-    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
-    group_gemm_sizes,
-    # device tensor of leading dimension sizes. its shape is [group_size, 3]
-    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
-    g_lds,
-    # Factors for multiplication.
-    alphas,
-    betas,
-    # number of gemms
-    group_size,
-    # number of virtual SM
-    NUM_SM: tl.constexpr,
-    # tile sizes
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-):
-    tile_idx = tl.program_id(0)
-    last_problem_end = 0
-    for g in range(group_size):
-        # get the gemm size of the current problem
-        gm = tl.load(group_gemm_sizes + g * 3)
-        gn = tl.load(group_gemm_sizes + g * 3 + 1)
-        gk = tl.load(group_gemm_sizes + g * 3 + 2)
-        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
-        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
-        num_tiles = num_m_tiles * num_n_tiles
-        # load multiplication factors
-        alpha = tl.load(alphas + g)
-        beta = tl.load(betas + g)
-        # iterate through the tiles in the current gemm problem
-        while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
-            # pick up a tile from the current gemm problem
-            k = gk
-            lda = tl.load(g_lds + g * 3)
-            ldb = tl.load(g_lds + g * 3 + 1)
-            ldc = tl.load(g_lds + g * 3 + 2)
-            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
-            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
-            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
-            # figure out tile coordinates
-            tile_idx_in_gemm = tile_idx - last_problem_end
-            tile_m_idx = tile_idx_in_gemm // num_n_tiles
-            tile_n_idx = tile_idx_in_gemm % num_n_tiles
-
-            # do regular gemm here
-            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-            offs_k = tl.arange(0, BLOCK_SIZE_K)
-            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
-            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
-            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
-                a = tl.load(
-                    a_ptrs,
-                    mask=(offs_am[:, None] < gm)
-                    and (offs_k[None, :] < gk - kk * BLOCK_SIZE_K),
-                    other=0.0,
-                )
-                b = tl.load(
-                    b_ptrs,
-                    mask=(offs_k[:, None] < gk - kk * BLOCK_SIZE_K)
-                    and (offs_bn[None, :] < gn),
-                    other=0.0,
-                )
-                accumulator += tl.dot(a, b)
-                a_ptrs += BLOCK_SIZE_K
-                b_ptrs += BLOCK_SIZE_K * ldb
-            accumulator *= alpha
-            c = accumulator.to(tl.float16)
-
-            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]
-            output_mask = (offs_am[:, None] < gm) and (offs_bn[None, :] < gn)
-            c += beta * tl.load(c_ptrs, mask=output_mask)
-            tl.store(c_ptrs, c, mask=output_mask)
-
-            # go to the next tile by advancing NUM_SM
-            tile_idx += NUM_SM
-
-        # get ready to go to the next gemm problem
-        last_problem_end = last_problem_end + num_tiles
-
-
-def triton_perf_fn(group_A, group_B, group_C, dtype):
-    # We put the process of matrix lengths and pointers here out of fairness,
-    # since cublas_grouped_gemm kernel also does these work.
-    group_size = len(group_A)
-    A_addrs = []
-    B_addrs = []
-    C_addrs = []
-    g_sizes = []
-    g_lds = []
-    alphas = [1.0] * group_size
-    betas = [0.0] * group_size
-    for i in range(group_size):
-        M, N, K = group_A[i].shape[0], group_B[i].shape[1], group_A[i].shape[1]
-        g_sizes += [M, N, K]
-        g_lds += [K, N, N]
-        A_addrs.append(group_A[i].data_ptr())
-        B_addrs.append(group_B[i].data_ptr())
-        C_addrs.append(group_C[i].data_ptr())
-
-    d_a_ptrs = torch.tensor(A_addrs, device="cuda")
-    d_b_ptrs = torch.tensor(B_addrs, device="cuda")
-    d_c_ptrs = torch.tensor(C_addrs, device="cuda")
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda")
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda")
-    d_alphas = torch.tensor(alphas, dtype=torch.float32, device="cuda")
-    d_betas = torch.tensor(betas, dtype=torch.float32, device="cuda")
-
-    NUM_SM = 128
-    grid = (NUM_SM,)
-    grouped_matmul_kernel[grid](
-        d_a_ptrs,
-        d_b_ptrs,
-        d_c_ptrs,
-        d_g_sizes,
-        d_g_lds,
-        d_alphas,
-        d_betas,
-        group_size,
-        NUM_SM=NUM_SM,
-        BLOCK_SIZE_M=128,
-        BLOCK_SIZE_N=128,
-        BLOCK_SIZE_K=32,
-    )
-
-
-def cublas_perf_fn(group_A, group_B, group_C, dtype):
-    cublas_grouped_gemm(group_A, group_B, group_C, dtype)
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["M"],
-        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
-        x_log=False,
-        line_arg="provider",
-        line_vals=[
-            "triton",
-            "cublas",
-        ],
-        line_names=[
-            "triton",
-            "cublas",
-        ],
-        styles=[("green", "-"), ("blue", "-")],
-        ylabel="gbps",
-        plot_name="grouped gemm",
-        args={},
-    )
-)
-def benchmark(M, provider, N, K):
-    group_size = 20  # Number of used experts per gpu is usually around 20
-    group_A = []
-    group_B_row_major = []
-    group_B_col_major = []
-    group_C = []
-    dtype = torch.float16
-    for i in range(group_size):
-        A = torch.rand((M, K), device="cuda", dtype=dtype)
-        B_row_major = torch.rand((K, N), device="cuda", dtype=dtype)
-        B_col_major = torch.rand((N, K), device="cuda", dtype=dtype)
-        C = torch.empty((M, N), device="cuda", dtype=dtype)
-        group_A.append(A)
-        group_B_row_major.append(B_row_major)
-        group_B_col_major.append(B_col_major)
-        group_C.append(C)
-
-    quantiles = [0.5, 0.2, 0.8]
-    if "triton" in provider:
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: triton_perf_fn(group_A, group_B_row_major, group_C, dtype),
-            quantiles=quantiles,
-        )
-    elif "cublas" in provider:
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: cublas_perf_fn(group_A, group_B_col_major, group_C, dtype),
-            quantiles=quantiles,
-        )
-
-    gbps = (
-        lambda ms: group_size
-        * (2 * M * N * K + 2 * M * N)
-        * group_A[0].element_size()
-        * 1e-9
-        / (ms * 1e-3)
-    )
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--models",
-        nargs="+",
-        type=str,
-        default=["DeepSeek-V2"],
-        help="List of models to benchmark",
-    )
-    parser.add_argument(
-        "--tp-size",
-        type=int,
-        default=8,
-        help="Tensor parallel size",
-    )
-    args = parser.parse_args()
-    for model in args.models:
-        assert model in WEIGHT_CONFIGS
-        num_experts_per_device = (
-            WEIGHT_CONFIGS[model]["num_routed_experts"] // args.tp_size
-        )
-        for K, N in WEIGHT_CONFIGS[model]["ffn_shapes"]:
-            print(
-                f"{model} N={N} K={K} tp_size={args.tp_size} "
-                f"group_size=num_experts_per_device={num_experts_per_device}: "
-            )
-            benchmark.run(
-                print_data=True,
-                show_plots=True,
-                save_path="bench_grouped_gemm_res",
-                N=N,
-                K=K,
-            )
-
-    print("Benchmark finished!")
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index ea9060972..c9b2c8516 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -112,11 +112,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
   m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
 
-  m.def(
-      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
-      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
-  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
-
   m.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
diff --git a/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu b/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
deleted file mode 100644
index ca7d131e1..000000000
--- a/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
+++ /dev/null
@@ -1,172 +0,0 @@
-// References:
-// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex
-// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu
-// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp
-
-#include <cublasLt.h>
-#include <cublas_v2.h>
-#include <cuda.h>
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <cuda_fp8.h>
-#include <cuda_runtime.h>
-#include <torch/all.h>
-#include <torch/extension.h>
-
-#include <cstdint>
-#include <iostream>
-#include <string>
-#include <vector>
-
-#include "utils.h"
-
-static void check_group_count(
-    const std::vector<torch::Tensor>& inputs,
-    const std::vector<torch::Tensor>& weights,
-    const std::vector<torch::Tensor>& outputs) {
-  TORCH_CHECK(
-      ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
-      "The group count of inputs, weights and outputs should be the same.");
-}
-
-static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
-  for (const auto& t : tensors) {
-    TORCH_CHECK(dtype == t.dtype(), "dtype of all the tensors should be the same");
-    TORCH_CHECK(t.is_cuda(), "All tensors should be in Cuda memory");
-  }
-}
-
-static std::vector<int> get_dims(const std::vector<torch::Tensor>& tensors, int dim) {
-  std::vector<int> results;
-  for (const auto& t : tensors) {
-    TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices");
-    results.push_back(t.size(dim));
-  }
-  return std::move(results);
-}
-
-static std::vector<int> get_strides(const std::vector<torch::Tensor>& tensors, int dim) {
-  std::vector<int> results;
-  for (const auto& t : tensors) {
-    results.push_back(t.stride(dim));
-  }
-  return std::move(results);
-}
-
-static void check_equal(const std::vector<int>& a, const std::vector<int>& b, const std::string& err_msg) {
-  for (int i = 0; i < a.size(); ++i) {
-    TORCH_CHECK(a[i] == b[i], err_msg);
-  }
-}
-
-static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tensors) {
-  std::vector<void*> ptrs;
-  for (auto& t : tensors) {
-    ptrs.push_back(t.data_ptr());
-  }
-  return std::move(ptrs);
-}
-
-static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
-  auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
-  torch::Tensor gpu_ptrs = torch::empty({static_cast<int64_t>(ptrs.size())}, options);
-  TORCH_CHECK(
-      cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
-      CUBLAS_STATUS_SUCCESS);
-  return gpu_ptrs;
-}
-
-// We want compute input @ weight^T in row major
-// This is equivalent to computing weight @ input^T in col major
-// Cublas only accepts matrix in column major, so this arrangement is needed
-void cublas_grouped_gemm(
-    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
-    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
-    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
-    const torch::Dtype& out_dtype,
-    int64_t cublas_handle,
-    int64_t cuda_stream) {
-  TORCH_CHECK(
-      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
-      "cublas grouped_gemm can"
-      "only be applied to float16 and bfloat16 dtype");
-
-  int group_count = inputs.size();
-  check_group_count(inputs, weights, outputs);
-  std::vector<int> group_size(group_count, 1);
-
-  // Make sure all tensors are on cuda and use the same dtype
-  check_device_dtype(out_dtype, inputs);
-  check_device_dtype(out_dtype, weights);
-  check_device_dtype(out_dtype, outputs);
-
-  // Weights should be transposed to (n, k) of column major
-  std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
-  std::vector<cublasOperation_t> transb_array(group_count, CUBLAS_OP_N);
-
-  // Get dim arrays
-  std::vector<int> m_array = get_dims(weights, 0);
-  std::vector<int> n_array = get_dims(inputs, 0);
-  std::vector<int> k_array = get_dims(inputs, 1);
-
-  // Make sure the dimensions in each group match
-  std::vector<int> m_array1 = get_dims(outputs, 1);
-  std::vector<int> n_array1 = get_dims(outputs, 0);
-  std::vector<int> k_array1 = get_dims(weights, 1);
-  check_equal(m_array, m_array1, "sizes don't match on m dimension");
-  check_equal(n_array, n_array1, "sizes don't match on n dimension");
-  check_equal(k_array, k_array1, "sizes don't match on k dimension");
-
-  // Get leading dimensions
-  std::vector<int> lda_array = get_strides(weights, 0);
-  std::vector<int> ldb_array = get_strides(inputs, 0);
-  std::vector<int> ldc_array = get_strides(outputs, 0);
-
-  // Use default scaling factors
-  std::vector<float> alpha_array(group_count, 1);
-  std::vector<float> beta_array(group_count, 0);
-
-  std::vector<void*> a_array = get_tensor_ptrs(weights);
-  std::vector<void*> b_array = get_tensor_ptrs(inputs);
-  std::vector<void*> c_array = get_tensor_ptrs(outputs);
-
-  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-
-  // Should allocate tensors for storage of pointers
-  torch::Tensor d_a = create_ptr_pointer(a_array, stream);
-  torch::Tensor d_b = create_ptr_pointer(b_array, stream);
-  torch::Tensor d_c = create_ptr_pointer(c_array, stream);
-
-#if defined CUDA_VERSION && CUDA_VERSION >= 12050
-  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
-  cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
-
-  auto status = cublasGemmGroupedBatchedEx(
-      handle,
-      transa_array.data(),
-      transb_array.data(),
-      m_array.data(),
-      n_array.data(),
-      k_array.data(),
-      alpha_array.data(),
-      (void**)d_a.data_ptr(),
-      cuda_data_type,
-      lda_array.data(),
-      (void**)d_b.data_ptr(),
-      cuda_data_type,
-      ldb_array.data(),
-      beta_array.data(),
-      (void**)d_c.data_ptr(),
-      cuda_data_type,
-      ldc_array.data(),
-      group_count,
-      group_size.data(),
-      CUBLAS_COMPUTE_32F);
-  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
-  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
-  return;
-#endif
-
-  TORCH_CHECK_NOT_IMPLEMENTED(
-      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
-}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 8f8abb234..4aa8535f5 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -160,13 +160,6 @@ void sgl_per_token_group_quant_int8(
     double int8_max);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
-void cublas_grouped_gemm(
-    const std::vector<torch::Tensor>& inputs,
-    const std::vector<torch::Tensor>& weights,
-    const std::vector<torch::Tensor>& outputs,
-    const torch::Dtype& out_dtype,
-    int64_t cublas_handle,
-    int64_t cuda_stream);
 void bmm_fp8(
     at::Tensor A,
     at::Tensor B,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 789cbcedf..0fc68cfcc 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -25,7 +25,6 @@ from sgl_kernel.elementwise import (
 from sgl_kernel.gemm import (
     awq_dequantize,
     bmm_fp8,
-    cublas_grouped_gemm,
     cutlass_scaled_fp4_mm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index b1ef34596..7035519c2 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -121,26 +121,6 @@ def sgl_per_tensor_quant_fp8(
     )
 
-
-def cublas_grouped_gemm(
-    inputs: List[torch.Tensor],
-    weights: List[torch.Tensor],
-    outputs: List[torch.Tensor],
-    out_dtype: torch.dtype,
-) -> None:
-    assert (
-        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
-    ), "Inputs/weights/outputs should not be empty!"
-    cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
-        inputs,
-        weights,
-        outputs,
-        out_dtype,
-        cublas_handle,
-        get_cuda_stream(),
-    )
-
 
 def sgl_per_token_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
diff --git a/sgl-kernel/tests/test_cublas_grouped_gemm.py b/sgl-kernel/tests/test_cublas_grouped_gemm.py
deleted file mode 100644
index 70b3dc5cf..000000000
--- a/sgl-kernel/tests/test_cublas_grouped_gemm.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import pytest
-import torch
-from sgl_kernel import cublas_grouped_gemm
-
-
-def torch_grouped_gemm(a_array, b_array, out_dtype):
-    return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]
-
-
-skip_condition = not torch.cuda.is_available() or (
-    torch.version.cuda is None
-    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
-)
-
-
-@pytest.mark.skipif(
-    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
-)
-@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
-@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
-@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
-def test_grouped_gemm_accuracy(out_dtype, M, N, K):
-    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
-    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
-    expected = torch.matmul(a, b.t()).to(out_dtype)
-
-    a_array = [a]
-    b_array = [b]
-    c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]
-
-    result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
-    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)
-
-    torch.testing.assert_close(result_torch, expected)
-    torch.testing.assert_close(c_array[0], expected)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
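
For any remaining call sites: per the deleted test's reference implementation (torch.matmul(a, b.t())), the removed op computed outputs[i] = inputs[i] @ weights[i]^T for each group member, with alpha=1 and beta=0. A minimal eager-PyTorch sketch of an equivalent fallback follows; the name grouped_gemm_fallback is hypothetical and not an sgl-kernel API.

    from typing import List

    import torch


    def grouped_gemm_fallback(
        inputs: List[torch.Tensor],   # each (M_i, K_i), fp16/bf16, on CUDA
        weights: List[torch.Tensor],  # each (N_i, K_i); multiplied transposed
        outputs: List[torch.Tensor],  # each (M_i, N_i), preallocated in out_dtype
        out_dtype: torch.dtype,
    ) -> None:
        assert len(inputs) == len(weights) == len(outputs) > 0
        for a, b, c in zip(inputs, weights, outputs):
            assert c.dtype == out_dtype
            # Same math the removed kernel performed: c = a @ b^T (alpha=1, beta=0)
            torch.matmul(a, b.t(), out=c)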