diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 89a298c34..e8f9a0839 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -222,7 +222,9 @@ set(SOURCES
   "csrc/gemm/awq_kernel.cu"
   "csrc/gemm/bmm_fp8.cu"
   "csrc/gemm/dsv3_fused_a_gemm.cu"
-  "csrc/gemm/dsv3_router_gemm.cu"
+  "csrc/gemm/dsv3_router_gemm_bf16_out.cu"
+  "csrc/gemm/dsv3_router_gemm_entry.cu"
+  "csrc/gemm/dsv3_router_gemm_float_out.cu"
   "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
   "csrc/gemm/fp8_gemm_kernel.cu"
   "csrc/gemm/int8_gemm_kernel.cu"
diff --git a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
index 16b3143f0..4502746f9 100644
--- a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
+++ b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py
@@ -7,6 +7,48 @@ import triton.testing
 from sgl_kernel import dsv3_router_gemm
 
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[i + 1 for i in range(16)],
+        x_log=False,
+        line_arg="impl",
+        line_vals=["torch", "sgl-kernel"],
+        line_names=["torch", "dsv3_router_gemm"],
+        styles=[("blue", "-"), ("orange", "-")],
+        ylabel="TFLOPs",
+        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
+        args={},
+    )
+)
+def benchmark_bf16_output(num_tokens, impl):
+    # M: num_tokens, K: hidden_dim, N: num_experts
+    M, K, N = num_tokens, 7168, 256
+
+    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
+    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if impl == "torch":
+
+        def runner():
+            F.linear(mat_a, mat_b)
+
+    elif impl == "sgl-kernel":
+
+        def runner():
+            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
+
+    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
+
+    def tflops(t_ms):
+        flops = 2 * M * K * N
+        return flops / (t_ms * 1e-3) / 1e12
+
+    return tflops(ms), tflops(max_ms), tflops(min_ms)
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens"],
@@ -21,7 +63,7 @@ from sgl_kernel import dsv3_router_gemm
         args={},
     )
 )
-def benchmark(num_tokens, impl):
+def benchmark_float_output(num_tokens, impl):
     # M: num_tokens, K: hidden_dim, N: num_experts
     M, K, N = num_tokens, 7168, 256
 
@@ -38,7 +80,7 @@ def benchmark(num_tokens, impl):
     elif impl == "sgl-kernel":
 
         def runner():
-            dsv3_router_gemm(mat_a, mat_b)
+            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
 
     ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
 
@@ -53,4 +95,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     args = parser.parse_args()
 
-    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm")
+    benchmark_bf16_output.run(
+        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
+    )
+    benchmark_float_output.run(
+        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
+    )
diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
new file mode 100644
index 000000000..ef011dfb0
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
@@ -0,0 +1,234 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
+ * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
+ *
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+
+#include "cuda_bf16.h"
+#include "cuda_runtime.h"
+#include "utils.h"
+
+// Custom FMA implementation using PTX assembly instructions
+__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
+  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
+               : "=l"(reinterpret_cast<uint64_t&>(d))
+               : "l"(reinterpret_cast<uint64_t const&>(a)),
+                 "l"(reinterpret_cast<uint64_t const&>(b)),
+                 "l"(reinterpret_cast<uint64_t const&>(c)));
+}
+
+// Convert 8 bfloat16 values from a uint4 to a float array - optimized conversion
+template <int VPT>
+__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
+  __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));
+
+#pragma unroll
+  for (int i = 0; i < VPT; i++) {
+    dst[i] = __bfloat162float(bf16_ptr[i]);
+  }
+}
+
+template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim, int kBlockSize, int VPT>
+__global__
+__launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(__nv_bfloat16* out, T const* mat_a, T const* mat_b) {
+  // Each block handles one expert column
+  int const n_idx = blockIdx.x;
+  int const tid = threadIdx.x;
+  constexpr int kWarpSize = 32;
+  constexpr int kNumWarps = kBlockSize / kWarpSize;
+  // Constants for this kernel
+  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
+  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;  // Total K iterations
+
+  // Initialize accumulators for all M rows
+  float acc[kNumTokens] = {};
+
+  // Shared memory for warp-level reduction
+  __shared__ float sm_reduction[kNumTokens][kNumWarps];
+
+  // B matrix is in column-major order, so we can directly load a column for the n_idx expert
+  T const* b_col = mat_b + n_idx * kHiddenDim;
+
+  // Pre-compute k_base values for each iteration to help the compiler optimize
+  int k_bases[k_iterations];
+#pragma unroll
+  for (int ki = 0; ki < k_iterations; ki++) {
+    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+
+  // Process the GEMM in chunks
+  for (int ki = 0; ki < k_iterations; ki++) {
+    int const k_base = k_bases[ki];
+
+    // Load B matrix values using a vector load (8 bf16 values)
+    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
+
+    // Convert B values to float
+    float b_float[VPT];
+    bf16_uint4_to_float8<VPT>(b_vec, b_float);
+
+// Process each token
+#pragma unroll
+    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
+      // Load this token's row of the A matrix using a vector load
+      uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);
+
+      // Convert A values to float
+      float a_float[VPT];
+      bf16_uint4_to_float8<VPT>(a_vec, a_float);
+
+// Process elements in this chunk
+#pragma unroll
+      for (int k = 0; k < VPT; k++) {
+        float a = a_float[k];
+        float b = b_float[k];
+        acc[m_idx] += a * b;
+      }
+    }
+  }
+
+  // Perform warp-level reduction
+  int const warpId = tid / kWarpSize;
+  int const laneId = tid % kWarpSize;
+
+  // Registers for warp-level reduction results
+  float warp_result[kNumTokens];
+
+#pragma unroll
+  for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
+    warp_result[m_idx] = acc[m_idx];
+  }
+
+// Perform warp-level reduction using an optimized butterfly pattern
+#pragma unroll
+  for (int m = 0; m < kNumTokens; m++) {
+    float sum = warp_result[m];
+
+    // Butterfly reduction pattern
+    sum += __shfl_xor_sync(0xffffffff, sum, 16);
+    sum += __shfl_xor_sync(0xffffffff, sum, 8);
+    sum += __shfl_xor_sync(0xffffffff, sum, 4);
+    sum += __shfl_xor_sync(0xffffffff, sum, 2);
+    sum += __shfl_xor_sync(0xffffffff, sum, 1);
+
+    // Only the first thread in each warp stores to shared memory
+    if (laneId == 0) {
+      sm_reduction[m][warpId] = sum;
+    }
+  }
+
+  __syncthreads();
+
+  // Final reduction across warps (only the first thread)
+  if (tid == 0) {
+#pragma unroll
+    for (int m = 0; m < kNumTokens; m++) {
+      float final_sum = 0.0f;
+
+// Sum across the kNumWarps warps
+#pragma unroll
+      for (int w = 0; w < kNumWarps; w++) {
+        final_sum += sm_reduction[m][w];
+      }
+
+      // Write the final result, converted to bf16
+      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
+    }
+  }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
+void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
+  constexpr int VPT = 16 / sizeof(T);
+  constexpr int kBlockSize = 128;
+  cudaLaunchConfig_t config;
+  config.gridDim = kNumExperts;
+  config.blockDim = kBlockSize;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(
+      &config,
+      router_gemm_kernel_bf16_output<T, kNumTokens, kNumExperts, kHiddenDim, kBlockSize, VPT>,
+      output,
+      mat_a,
+      mat_b);
+}
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
new file mode 100644
index 000000000..c316a8193
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
@@ -0,0 +1,127 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
+ * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
+ *
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+
+#include "cuda_bf16.h"
+#include "cuda_runtime.h"
+#include "utils.h"
+
+template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
+void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream);
+
+template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
+void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream);
+
+template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
+struct LoopUnroller {
+  static void unroll_float_output(
+      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
+    if (num_tokens == kBegin) {
+      invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
+    } else {
+      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_float_output(
+          num_tokens, output, input, weights, stream);
+    }
+  }
+
+  static void unroll_bf16_output(
+      int num_tokens,
+      __nv_bfloat16* output,
+      __nv_bfloat16 const* input,
+      __nv_bfloat16 const* weights,
+      cudaStream_t stream) {
+    if (num_tokens == kBegin) {
+      invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
+    } else {
+      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_bf16_output(
+          num_tokens, output, input, weights, stream);
+    }
+  }
+};
+
+template <int kEnd, int kNumExperts, int kHiddenDim>
+struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
+  static void unroll_float_output(
+      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
+    if (num_tokens == kEnd) {
+      invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
+    } else {
+      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
+    }
+  }
+
+  static void unroll_bf16_output(
+      int num_tokens,
+      __nv_bfloat16* output,
+      __nv_bfloat16 const* input,
+      __nv_bfloat16 const* weights,
+      cudaStream_t stream) {
+    if (num_tokens == kEnd) {
+      invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
+    } else {
+      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
+    }
+  }
+};
+
+void dsv3_router_gemm(
+    torch::Tensor& output,       // [num_tokens, num_experts]
+    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
+    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
+) {
+  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
+
+  const int num_tokens = mat_a.size(0);
+  constexpr int num_experts = 256;
+  constexpr int hidden_dim = 7168;
+
+  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
+  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
+  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
+  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be between 1 and 16 for router_gemm");
+  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
+  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
+  TORCH_CHECK(
+      output.dtype() == torch::kFloat32 || output.dtype() == torch::kBFloat16, "output must be float32 or bf16");
+
+  auto const sm = getSMVersion();
+  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (output.dtype() == torch::kFloat32) {
+    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output(
+        num_tokens,
+        reinterpret_cast<float*>(output.mutable_data_ptr()),
+        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+        stream);
+  } else if (output.dtype() == torch::kBFloat16) {
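+    // Same unrolled dispatch as the float32 path above; the bf16 kernel converts
+    // the accumulated float sums to __nv_bfloat16 before storing the logits.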
+    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output(
+        num_tokens,
+        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
+        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+        stream);
+  }
+}
diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
similarity index 54%
rename from sgl-kernel/csrc/gemm/dsv3_router_gemm.cu
rename to sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
index 410bbcefd..e7577c55b 100644
--- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu
+++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
@@ -46,7 +46,7 @@ __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* ds
 }
 
 template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim, int kBlockSize, int VPT>
-__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) {
+__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(float* out, T const* mat_a, T const* mat_b) {
   // Each block handles one expert column
   int const n_idx = blockIdx.x;
   int const tid = threadIdx.x;
@@ -163,7 +163,7 @@ __global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const
 }
 
 template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
-void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
+void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
   constexpr int VPT = 16 / sizeof(T);
   constexpr int kBlockSize = 128;
   cudaLaunchConfig_t config;
@@ -177,110 +177,57 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_
   config.numAttrs = 1;
   config.attrs = attrs;
   cudaLaunchKernelEx(
-      &config, router_gemm_kernel<T, kNumTokens, kNumExperts, kHiddenDim, kBlockSize, VPT>, output, mat_a, mat_b);
+      &config,
+      router_gemm_kernel_float_output<T, kNumTokens, kNumExperts, kHiddenDim, kBlockSize, VPT>,
+      output,
+      mat_a,
+      mat_b);
 }
 
-template void
-invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
 
-template void
-invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-
-template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
-struct LoopUnroller {
-  static void
-  unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
-    if (num_tokens == kBegin) {
-      invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
-    } else {
-      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(num_tokens, output, input, weights, stream);
-    }
-  }
-};
-
-template <int kEnd, int kNumExperts, int kHiddenDim>
-struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
-  static void
-  unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
-    if (num_tokens == kEnd) {
-      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
-    } else {
-      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
-    }
-  }
-};
-
-void dsv3_router_gemm(
-    torch::Tensor& output,       // [num_tokens, num_experts]
-    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
-    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
-) {
-  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
-
-  const int num_tokens = mat_a.size(0);
-  constexpr int num_experts = 256;
-  constexpr int hidden_dim = 7168;
-
-  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
-  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
-  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
-  TORCH_CHECK(
-      num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm");
-  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
-  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
-  TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32");
-
-  auto const sm = getSMVersion();
-  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
-
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  LoopUnroller<1, 16, num_experts, hidden_dim>::unroll(
-      num_tokens,
-      reinterpret_cast<float*>(output.mutable_data_ptr()),
-      reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
-      reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
-      stream);
-}
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index 6ec4ce78a..7435cfdda 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -262,12 +262,13 @@ def qserve_w4a8_per_group_gemm(
 def dsv3_router_gemm(
     hidden_states: torch.Tensor,
     router_weights: torch.Tensor,
+    out_dtype: torch.dtype = torch.bfloat16,
 ) -> torch.Tensor:
     output = torch.empty(
         hidden_states.shape[0],
         router_weights.shape[0],
         device=hidden_states.device,
-        dtype=torch.float32,
+        dtype=out_dtype,
     )
     torch.ops.sgl_kernel.dsv3_router_gemm(
         output,
diff --git a/sgl-kernel/tests/test_dsv3_router_gemm.py b/sgl-kernel/tests/test_dsv3_router_gemm.py
index 1b60bcf92..169c99671 100644
--- a/sgl-kernel/tests/test_dsv3_router_gemm.py
+++ b/sgl-kernel/tests/test_dsv3_router_gemm.py
@@ -15,17 +15,20 @@ def test_dsv3_router_gemm(num_tokens):
     mat_b = torch.randn(
         (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
     ).contiguous()
-    output = torch.empty(
-        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
-    ).contiguous()
 
-    ref = F.linear(mat_a, mat_b).to(torch.float32)
+    bf16_ref = F.linear(mat_a, mat_b)
+    float_ref = bf16_ref.to(torch.float32)
 
-    output = dsv3_router_gemm(mat_a, mat_b)
+    bf16_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
+    float_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
 
     assert torch.allclose(
-        output, ref, rtol=1e-2, atol=1e-3
-    ), "Router GEMM output mismatch with torch.nn.functional.linear reference"
+        bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
+    ), "Router GEMM bf16 output mismatches the torch.nn.functional.linear reference"
+
+    assert torch.allclose(
+        float_output, float_ref, rtol=1e-2, atol=1e-3
+    ), "Router GEMM float32 output mismatches the torch.nn.functional.linear reference"
 
 
 if __name__ == "__main__":
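A minimal usage sketch of the updated Python API (illustrative only, not part of the patch; assumes a CUDA device and the constraints the entry point enforces: bf16 inputs, hidden_dim = 7168, num_experts = 256, and 1 <= num_tokens <= 16):

    import torch
    from sgl_kernel import dsv3_router_gemm

    # Per-token hidden states and the router weight matrix, both bf16 and contiguous.
    hidden_states = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
    router_weights = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")

    # out_dtype selects the kernel variant; bf16 is now the default.
    logits_bf16 = dsv3_router_gemm(hidden_states, router_weights, out_dtype=torch.bfloat16)
    logits_fp32 = dsv3_router_gemm(hidden_states, router_weights, out_dtype=torch.float32)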