diff --git a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
index 4502746f9..dee090e21 100644
--- a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
+++ b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
@@ -13,9 +13,14 @@ from sgl_kernel import dsv3_router_gemm
         x_vals=[i + 1 for i in range(16)],
         x_log=False,
         line_arg="impl",
-        line_vals=["torch", "sgl-kernel"],
-        line_names=["torch", "dsv3_router_gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
+        line_names=[
+            "torch-256",
+            "dsv3_router_gemm-256",
+            "torch-384",
+            "dsv3_router_gemm-384",
+        ],
+        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
         ylabel="TFLOPs",
         plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
         args={},
@@ -23,19 +28,26 @@ from sgl_kernel import dsv3_router_gemm
 )
 def benchmark_bf16_output(num_tokens, impl):
     # M: num_tokens, K: hidden_dim, N: num_experts
-    M, K, N = num_tokens, 7168, 256
+    M, K = num_tokens, 7168
+
+    if impl == "torch-256" or impl == "sgl-kernel-256":
+        N = 256
+    elif impl == "torch-384" or impl == "sgl-kernel-384":
+        N = 384
+    else:
+        raise ValueError(f"Unknown impl: {impl}")
 
     mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
     mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
     quantiles = [0.5, 0.2, 0.8]
 
-    if impl == "torch":
+    if impl == "torch-256" or impl == "torch-384":
 
         def runner():
             F.linear(mat_a, mat_b)
 
-    elif impl == "sgl-kernel":
+    elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
 
         def runner():
             dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
@@ -55,9 +67,14 @@ def benchmark_bf16_output(num_tokens, impl):
         x_vals=[i + 1 for i in range(16)],
         x_log=False,
         line_arg="impl",
-        line_vals=["torch", "sgl-kernel"],
-        line_names=["torch", "dsv3_router_gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
+        line_names=[
+            "torch-256",
+            "dsv3_router_gemm-256",
+            "torch-384",
+            "dsv3_router_gemm-384",
+        ],
+        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
         ylabel="TFLOPs",
         plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
         args={},
@@ -65,19 +82,26 @@ def benchmark_bf16_output(num_tokens, impl):
 )
 def benchmark_float_output(num_tokens, impl):
     # M: num_tokens, K: hidden_dim, N: num_experts
-    M, K, N = num_tokens, 7168, 256
+    M, K = num_tokens, 7168
+
+    if impl == "torch-256" or impl == "sgl-kernel-256":
+        N = 256
+    elif impl == "torch-384" or impl == "sgl-kernel-384":
+        N = 384
+    else:
+        raise ValueError(f"Unknown impl: {impl}")
 
     mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
     mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
     quantiles = [0.5, 0.2, 0.8]
 
-    if impl == "torch":
+    if impl == "torch-256" or impl == "torch-384":
 
         def runner():
             F.linear(mat_a, mat_b).to(torch.float32)
 
-    elif impl == "sgl-kernel":
+    elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
 
         def runner():
             dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
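Both benchmark functions plot TFLOPs; the measurement loop itself sits below the hunks shown here (the quantiles variable suggests triton.testing.do_bench). As a minimal sketch, not part of the patch, the plotted value can be reproduced from a measured latency like so (router_gemm_tflops and ms_latency are hypothetical names):

    # Sketch only: convert a measured latency into the TFLOPs value on the y-axis.
    # Assumes the standard 2*M*N*K FLOP count for a GEMM; ms_latency stands in
    # for the median time the elided benchmarking call returns.
    def router_gemm_tflops(num_tokens: int, num_experts: int, ms_latency: float) -> float:
        M, K, N = num_tokens, 7168, num_experts  # shapes used throughout this patch
        flops = 2 * M * N * K  # M*N outputs, K multiply-adds each, 2 FLOPs per MAC
        return flops / (ms_latency * 1e-3) / 1e12

    # Example: 16 tokens over 384 experts in 0.01 ms is roughly 8.8 TFLOPs.
    print(router_gemm_tflops(16, 384, 0.01))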
diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
index ef011dfb0..e613bd75c 100644
--- a/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
+++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
@@ -185,6 +185,7 @@ void invokeRouterGemmBf16Output(__nv_bfloat16* output,
       T const* mat_a,
       T const* mat_b);
 }
 
+// Template instantiations for DEFAULT_NUM_EXPERTS experts
 template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(
     __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
@@ -232,3 +233,52 @@ template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(
 
 template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(
     __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+// Template instantiations for KIMI_K2_NUM_EXPERTS experts
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
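The new 384-expert block mirrors the existing 256-expert block one for one: num_tokens is a compile-time template parameter, so every value from 1 to 16 needs its own explicit instantiation for each supported expert count. As a compact statement of the pattern, a throwaway generator might look like this (sketch only, hypothetical helper, not part of the build):

    # Prints the instantiation list in this file: one explicit instantiation
    # per (num_tokens, num_experts) pair, with hidden_dim fixed at 7168.
    TEMPLATE = (
        "template void invokeRouterGemmBf16Output<__nv_bfloat16, {m}, {n}, 7168>(\n"
        "    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);\n"
    )

    for n in (256, 384):  # DEFAULT_NUM_EXPERTS, KIMI_K2_NUM_EXPERTS
        for m in range(1, 17):  # num_tokens 1..16, matching the entry point's check
            print(TEMPLATE.format(m=m, n=n))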
diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
index c316a8193..4f09e6cf4 100644
--- a/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
+++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
@@ -25,6 +25,10 @@
 #include "cuda_runtime.h"
 #include "utils.h"
 
+static constexpr int DEFAULT_NUM_EXPERTS = 256;
+static constexpr int KIMI_K2_NUM_EXPERTS = 384;
+static constexpr int DEFAULT_HIDDEN_DIM = 7168;
+
 template <typename T>
 void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream);
 
@@ -91,12 +95,24 @@ void dsv3_router_gemm(
   TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
   const int num_tokens = mat_a.size(0);
-  constexpr int num_experts = 256;
-  constexpr int hidden_dim = 7168;
+  const int num_experts = mat_b.size(0);
+  const int hidden_dim = mat_a.size(1);
 
   TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
-  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
-  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
+  TORCH_CHECK(
+      hidden_dim == DEFAULT_HIDDEN_DIM,
+      "Expected hidden_dim=",
+      DEFAULT_HIDDEN_DIM,
+      ", but got hidden_dim=",
+      hidden_dim);
+  TORCH_CHECK(
+      num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS,
+      "Expected num_experts=",
+      DEFAULT_NUM_EXPERTS,
+      " or num_experts=",
+      KIMI_K2_NUM_EXPERTS,
+      ", but got num_experts=",
+      num_experts);
   TORCH_CHECK(
       num_tokens >= 1 && num_tokens <= 16,
       "currently num_tokens must be less than or equal to 16 for router_gemm");
   TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
@@ -110,18 +126,36 @@ void dsv3_router_gemm(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   if (output.dtype() == torch::kFloat32) {
-    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output(
-        num_tokens,
-        reinterpret_cast<float*>(output.mutable_data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
-        stream);
+    if (num_experts == DEFAULT_NUM_EXPERTS) {
+      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
+          num_tokens,
+          reinterpret_cast<float*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
+      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
+          num_tokens,
+          reinterpret_cast<float*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    }
   } else if (output.dtype() == torch::kBFloat16) {
-    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output(
-        num_tokens,
-        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
-        stream);
+    if (num_experts == DEFAULT_NUM_EXPERTS) {
+      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
+          num_tokens,
+          reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
+      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
+          num_tokens,
+          reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    }
   }
 }
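Taken together, the checks and the two-level dispatch above define the kernel's calling contract. A minimal usage sketch, assuming a CUDA device, an sgl_kernel build containing this patch, and that the Python wrapper allocates and returns the output tensor (the C++ entry point itself writes into a caller-provided output):

    # Sketch of the contract enforced by the TORCH_CHECKs above; the exact
    # return behaviour of the Python wrapper is an assumption here.
    import torch
    from sgl_kernel import dsv3_router_gemm

    num_tokens, hidden_dim, num_experts = 4, 7168, 384  # 1 <= M <= 16; N in {256, 384}; K fixed

    mat_a = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16, device="cuda")
    mat_b = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, device="cuda")

    # out_dtype picks the kernel family: torch.float32 -> unroll_float_output,
    # torch.bfloat16 -> unroll_bf16_output.
    logits = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
    assert logits.shape == (num_tokens, num_experts)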
diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
index e7577c55b..88a364e2c 100644
--- a/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
+++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
@@ -184,6 +184,7 @@ void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
       mat_b);
 }
 
+// Template instantiations for DEFAULT_NUM_EXPERTS experts
 template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
     float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
@@ -231,3 +232,52 @@ template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
 
 template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
     float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+// Template instantiations for KIMI_K2_NUM_EXPERTS experts
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
diff --git a/sgl-kernel/tests/test_dsv3_router_gemm.py b/sgl-kernel/tests/test_dsv3_router_gemm.py
index 169c99671..575769d6d 100644
--- a/sgl-kernel/tests/test_dsv3_router_gemm.py
+++ b/sgl-kernel/tests/test_dsv3_router_gemm.py
@@ -5,8 +5,8 @@ from sgl_kernel import dsv3_router_gemm
 
 
 @pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
-def test_dsv3_router_gemm(num_tokens):
-    num_experts = 256
+@pytest.mark.parametrize("num_experts", [256, 384])
+def test_dsv3_router_gemm(num_tokens, num_experts):
     hidden_dim = 7168
     mat_a = torch.randn(