diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 72f3c0699..739b2b1c2 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -222,6 +222,7 @@ set(SOURCES
     "csrc/gemm/awq_kernel.cu"
     "csrc/gemm/bmm_fp8.cu"
     "csrc/gemm/dsv3_fused_a_gemm.cu"
+    "csrc/gemm/dsv3_router_gemm.cu"
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
     "csrc/gemm/fp8_gemm_kernel.cu"
     "csrc/gemm/int8_gemm_kernel.cu"
diff --git a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
new file mode 100644
index 000000000..16b3143f0
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
@@ -0,0 +1,56 @@
+import argparse
+
+import torch
+import torch.nn.functional as F
+import triton
+import triton.testing
+from sgl_kernel import dsv3_router_gemm
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[i + 1 for i in range(16)],
+        x_log=False,
+        line_arg="impl",
+        line_vals=["torch", "sgl-kernel"],
+        line_names=["torch", "dsv3_router_gemm"],
+        styles=[("blue", "-"), ("orange", "-")],
+        ylabel="TFLOPs",
+        plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
+        args={},
+    )
+)
+def benchmark(num_tokens, impl):
+    # M: num_tokens, K: hidden_dim, N: num_experts
+    M, K, N = num_tokens, 7168, 256
+
+    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
+    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if impl == "torch":
+
+        def runner():
+            F.linear(mat_a, mat_b).to(torch.float32)
+
+    elif impl == "sgl-kernel":
+
+        def runner():
+            dsv3_router_gemm(mat_a, mat_b)
+
+    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
+
+    def tflops(t_ms):
+        flops = 2 * M * K * N
+        return flops / (t_ms * 1e-3) / 1e12
+
+    return tflops(ms), tflops(max_ms), tflops(min_ms)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args()
+
+    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm")
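For context on the numbers this benchmark reports: the `tflops` helper counts `2 * M * K * N` FLOPs per call, so even the largest supported batch is a tiny GEMM. A quick worked example of the arithmetic (illustrative only; the latency below is an assumed figure, not a measurement from this PR):

```python
# Worked example of the benchmark's TFLOPs math for num_tokens = 16.
M, K, N = 16, 7168, 256
flops = 2 * M * K * N                # 58,720,256 FLOPs per router GEMM call
t_ms = 0.01                          # hypothetical 10 us kernel time
print(flops / (t_ms * 1e-3) / 1e12)  # ~5.87 TFLOPs at that assumed latency
```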
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index d8629e6ab..941d0f836 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -158,6 +158,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " Tensor expert_offsets, Tensor sf_offsets) -> ()");
   m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
 
+  m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
+  m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
+
   /*
    * From csrc/moe
    */
diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu
new file mode 100644
index 000000000..410bbcefd
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu
@@ -0,0 +1,286 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
+ * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
+ *
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+
+#include "cuda_bf16.h"
+#include "cuda_runtime.h"
+#include "utils.h"
+
+// Custom FMA implementation using PTX assembly instructions
+__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
+  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
+               : "=l"(reinterpret_cast<uint64_t&>(d))
+               : "l"(reinterpret_cast<uint64_t const&>(a)),
+                 "l"(reinterpret_cast<uint64_t const&>(b)),
+                 "l"(reinterpret_cast<uint64_t const&>(c)));
+}
+
+// Convert 8 bfloat16 values from a uint4 to a float array - optimized conversion
+template <int VPT>
+__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
+  __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));
+
+#pragma unroll
+  for (int i = 0; i < VPT; i++) {
+    dst[i] = __bfloat162float(bf16_ptr[i]);
+  }
+}
+
+template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim, int VPT, int kBlockSize>
+__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) {
+  // Each block handles one expert column
+  int const n_idx = blockIdx.x;
+  int const tid = threadIdx.x;
+  constexpr int kWarpSize = 32;
+  constexpr int kNumWarps = kBlockSize / kWarpSize;
+  // Constants for this kernel
+  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
+  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;  // Total K iterations
+
+  // Initialize accumulators for all M rows
+  float acc[kNumTokens] = {};
+
+  // Shared memory for the cross-warp reduction
+  __shared__ float sm_reduction[kNumTokens][kNumWarps];
+
+  // B matrix is in column-major order, so we can directly load a column for the n_idx expert
+  T const* b_col = mat_b + n_idx * kHiddenDim;
+
+  // Pre-compute k_base values for each iteration to help the compiler optimize
+  int k_bases[k_iterations];
+#pragma unroll
+  for (int ki = 0; ki < k_iterations; ki++) {
+    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+
+  // Process the GEMM in chunks
+  for (int ki = 0; ki < k_iterations; ki++) {
+    int const k_base = k_bases[ki];
+
+    // Load B matrix values using a vector load (8 bf16 values)
+    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
+
+    // Convert B values to float
+    float b_float[VPT];
+    bf16_uint4_to_float8<VPT>(b_vec, b_float);
+
+// Process each token
+#pragma unroll
+    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
+      // Load this token's row of the A matrix using a vector load
+      uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);
+
+      // Convert A values to float
+      float a_float[VPT];
+      bf16_uint4_to_float8<VPT>(a_vec, a_float);
+
+// Process elements in this chunk
+#pragma unroll
+      for (int k = 0; k < VPT; k++) {
+        float a = a_float[k];
+        float b = b_float[k];
+        acc[m_idx] += a * b;
+      }
+    }
+  }
+
+  // Perform warp-level reduction
+  int const warpId = tid / kWarpSize;
+  int const laneId = tid % kWarpSize;
+
+  // Registers for warp-level reduction results
+  float warp_result[kNumTokens];
+
+#pragma unroll
+  for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
+    warp_result[m_idx] = acc[m_idx];
+  }
+
+// Perform warp-level reduction using an optimized butterfly pattern
+#pragma unroll
+  for (int m = 0; m < kNumTokens; m++) {
+    float sum = warp_result[m];
+
+    // Butterfly reduction pattern
+    sum += __shfl_xor_sync(0xffffffff, sum, 16);
+    sum += __shfl_xor_sync(0xffffffff, sum, 8);
+    sum += __shfl_xor_sync(0xffffffff, sum, 4);
+    sum += __shfl_xor_sync(0xffffffff, sum, 2);
+    sum += __shfl_xor_sync(0xffffffff, sum, 1);
+
+    // Only the first thread in each warp stores to shared memory
+    if (laneId == 0) {
+      sm_reduction[m][warpId] = sum;
+    }
+  }
+
+  __syncthreads();
+
+  // Final reduction across warps (first thread only)
+  if (tid == 0) {
+#pragma unroll
+    for (int m = 0; m < kNumTokens; m++) {
+      float final_sum = 0.0f;
+
+// Sum across the kNumWarps partial results
+#pragma unroll
+      for (int w = 0; w < kNumWarps; w++) {
+        final_sum += sm_reduction[m][w];
+      }
+
+      // Write the final result
+      out[m * kNumExperts + n_idx] = final_sum;
+    }
+  }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
+void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
+  constexpr int VPT = 16 / sizeof(T);
+  constexpr int kBlockSize = 128;
+  cudaLaunchConfig_t config;
+  config.gridDim = kNumExperts;
+  config.blockDim = kBlockSize;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(
+      &config, router_gemm_kernel<T, kNumTokens, kNumExperts, kHiddenDim, VPT, kBlockSize>, output, mat_a, mat_b);
+}
+
+template void
+invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void
+invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
+struct LoopUnroller {
+  static void
+  unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
+    if (num_tokens == kBegin) {
+      invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
+    } else {
+      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(num_tokens, output, input, weights, stream);
+    }
+  }
+};
+
+// Base case of the compile-time dispatch: kBegin == kEnd.
+template <int kEnd, int kNumExperts, int kHiddenDim>
+struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
+  static void
+  unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
+    if (num_tokens == kEnd) {
+      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
+    } else {
+      throw std::invalid_argument("Invalid num_tokens, only 1 to 16 is supported");
+    }
+  }
+};
+
+void dsv3_router_gemm(
+    torch::Tensor& output,       // [num_tokens, num_experts]
+    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
+    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
+) {
+  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
+
+  const int num_tokens = mat_a.size(0);
+  constexpr int num_experts = 256;
+  constexpr int hidden_dim = 7168;
+
+  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
+  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
+  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
+  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be between 1 and 16 for router_gemm");
+  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
+  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
+  TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32");
+
+  auto const sm = getSMVersion();
+  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  LoopUnroller<1, 16, num_experts, hidden_dim>::unroll(
+      num_tokens,
+      reinterpret_cast<float*>(output.mutable_data_ptr()),
+      reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+      reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+      stream);
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 30bc5dab5..abea81e3f 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -200,6 +200,7 @@ void bmm_fp8(
     at::Tensor workspace_buffer,
     int64_t cublas_handle,
     int64_t cuda_stream);
+void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
 
 void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 4cba8cc78..2353004ce 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -34,6 +34,7 @@ from sgl_kernel.gemm import (
     bmm_fp8,
     cutlass_scaled_fp4_mm,
     dsv3_fused_a_gemm,
+    dsv3_router_gemm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index 2946c8ae9..6ec4ce78a 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -259,6 +259,25 @@ def qserve_w4a8_per_group_gemm(
     return out_feats
 
 
+def dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    router_weights: torch.Tensor,
+) -> torch.Tensor:
+    """DeepSeek-V3 router GEMM: bf16 [num_tokens, 7168] x bf16 [256, 7168]^T -> fp32 [num_tokens, 256]."""
+    output = torch.empty(
+        hidden_states.shape[0],
+        router_weights.shape[0],
+        device=hidden_states.device,
+        dtype=torch.float32,
+    )
+    torch.ops.sgl_kernel.dsv3_router_gemm(
+        output,
+        hidden_states,
+        router_weights,
+    )
+    return output
+
+
 def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
     output_tensor = torch.empty(
         output_tensor_shape,
diff --git a/sgl-kernel/tests/test_dsv3_router_gemm.py b/sgl-kernel/tests/test_dsv3_router_gemm.py
new file mode 100644
index 000000000..1b60bcf92
--- /dev/null
+++ b/sgl-kernel/tests/test_dsv3_router_gemm.py
@@ -0,0 +1,29 @@
+import pytest
+import torch
+import torch.nn.functional as F
+from sgl_kernel import dsv3_router_gemm
+
+
+@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
+def test_dsv3_router_gemm(num_tokens):
+    num_experts = 256
+    hidden_dim = 7168
+
+    mat_a = torch.randn(
+        (num_tokens, hidden_dim), dtype=torch.bfloat16, device="cuda"
+    ).contiguous()
+    mat_b = torch.randn(
+        (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
+    ).contiguous()
+
+    ref = F.linear(mat_a, mat_b).to(torch.float32)
+
+    output = dsv3_router_gemm(mat_a, mat_b)
+
+    assert torch.allclose(
+        output, ref, rtol=1e-2, atol=1e-3
+    ), "Router GEMM output mismatch with torch.nn.functional.linear reference"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
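A minimal usage sketch of the new op (assumes an SM90+ GPU and sgl-kernel rebuilt with this change; the top-k line only illustrates how DeepSeek-V3 routing would consume the logits and is not part of this diff):

```python
import torch
from sgl_kernel import dsv3_router_gemm

# DeepSeek-V3 router shapes: hidden_dim = 7168, num_experts = 256.
# The kernel supports 1 <= num_tokens <= 16 (small decode batches).
hidden_states = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
router_weight = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")

# fp32 logits of shape [num_tokens, num_experts]; matches
# F.linear(hidden_states, router_weight).float() within bf16 tolerance.
logits = dsv3_router_gemm(hidden_states, router_weight)

# Illustrative follow-on: DeepSeek-V3 selects top-8 routed experts per token.
topk_vals, topk_ids = torch.topk(logits, k=8, dim=-1)
```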