From e0917e6bd0fbbbbc8ba3db48ae26f07366ab9a0c Mon Sep 17 00:00:00 2001 From: Stefan He Date: Wed, 12 Mar 2025 00:08:03 -0700 Subject: [PATCH] Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215) Co-authored-by: Stefan He --- python/sglang/srt/custom_op.py | 59 +++++++++++++ python/sglang/srt/layers/moe/ep_moe/layer.py | 30 +++++-- .../layers/moe/fused_moe_triton/fused_moe.py | 20 +++-- python/sglang/test/test_custom_ops.py | 88 +++++++++++++++++++ sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 42 ++++----- 5 files changed, 202 insertions(+), 37 deletions(-) create mode 100644 python/sglang/test/test_custom_ops.py diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index 2e066efe8..d6fa29a70 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from torch import nn @@ -40,3 +42,60 @@ class CustomOp(nn.Module): return self.forward_hip else: return self.forward_native + + +if _is_cuda: + from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8 + + def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 (8-bit floating point) format. + + Args: + input (torch.Tensor): Input tensor to be quantized + scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization. + If None, scales will be computed dynamically. + use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None), + determines the quantization granularity: + - True: compute scale per token + - False: compute single scale per tensor + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - quantized_tensor: The FP8 quantized version of input + - scale_tensor: The scaling factors used for quantization + + Raises: + AssertionError: If input is not 2D or if static scale's numel != 1 + """ + assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" + shape = input.shape + out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + # Dynamic scaling + if use_per_token_if_dynamic: + scale = torch.empty( + (shape[0], 1), device=input.device, dtype=torch.float32 + ) + sgl_per_token_quant_fp8(input, output, scale) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + sgl_per_tensor_quant_fp8( + input, output, scale, is_static=False + ) # False for dynamic + else: + # Static scaling + assert ( + scale.numel() == 1 + ), f"Expected scalar scale, got numel={scale.numel()}" + sgl_per_tensor_quant_fp8( + input, output, scale, is_static=True + ) # True for static + + return output, scale diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 22eac0496..c94437221 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple import torch from torch.nn import Module -from vllm import _custom_ops as ops +from vllm import _custom_ops as vllm_ops from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( @@ -26,7 +26,13 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod -from 
sglang.srt.utils import is_hip, set_weight_attrs +from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs + +_is_cuda = is_cuda() + +if _is_cuda: + from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant + logger = logging.getLogger(__name__) @@ -719,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod): ) for expert in range(layer.num_experts_per_partition): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) + if _is_cuda: + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + else: + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) return diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 40dbdc35d..4ccaf59e6 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl -from vllm import _custom_ops as ops +from vllm import _custom_ops as vllm_ops from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 @@ -42,6 +42,7 @@ _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import gelu_and_mul, silu_and_mul + from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel import ( sglang_per_token_group_quant_fp8, ) @@ -486,7 +487,7 @@ def moe_align_block_size( cumsum_buffer, ) else: - ops.moe_align_block_size( + vllm_ops.moe_align_block_size( topk_ids, num_experts, block_size, @@ -527,7 +528,10 @@ def invoke_fused_moe_kernel( if block_shape is None: # activation tensor-wise fp8 quantization, dynamic or static padded_size = padding_size - A, A_scale = ops.scaled_fp8_quant(A, A_scale) + if _is_cuda: + A, A_scale = sgl_scaled_fp8_quant(A, A_scale) + else: + A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale) else: # activation block-wise fp8 quantization assert len(block_shape) == 2 @@ -1095,12 +1099,16 @@ def fused_experts_impl( if _is_cuda: silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + vllm_ops.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "gelu": if _is_cuda: gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: - ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + vllm_ops.gelu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) else: raise ValueError(f"Unsupported activation: {activation=}") @@ -1132,7 +1140,7 @@ def fused_experts_impl( if no_combine: pass elif _is_hip: - ops.moe_sum( + vllm_ops.moe_sum( 
intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], ) diff --git a/python/sglang/test/test_custom_ops.py b/python/sglang/test/test_custom_ops.py new file mode 100644 index 000000000..6fa343dcb --- /dev/null +++ b/python/sglang/test/test_custom_ops.py @@ -0,0 +1,88 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py + +import pytest +import torch + +from sglang.srt.custom_op import scaled_fp8_quant +from sglang.srt.utils import is_cuda + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_scaled_fp8_quant_per_tensor(dtype) -> None: + + def quantize_ref_per_tensor(tensor, inv_scale): + # The reference implementation that fully aligns to + # the kernel being tested. + finfo = torch.finfo(torch.float8_e4m3fn) + scale = inv_scale.reciprocal() + qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max) + qweight = qweight.to(torch.float8_e4m3fn) + return qweight + + def dequantize_per_tensor(tensor, inv_scale, dtype): + fake_qweight = tensor.to(dtype) + dq_weight = fake_qweight * inv_scale + return dq_weight + + # Note that we use a shape % 8 != 0 to cover edge cases, + # because scaled_fp8_quant is vectorized by 8. + x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype) + + # Test Per Tensor Dynamic quantization + # scale = max(abs(x)) / FP8_E4M3_MAX + y, scale = scaled_fp8_quant(x, None) + ref_y = quantize_ref_per_tensor(x, scale) + torch.testing.assert_close(y, ref_y) + torch.testing.assert_close( + dequantize_per_tensor(y, scale, dtype), + dequantize_per_tensor(ref_y, scale, dtype), + ) + + # Test Per Tensor Static quantization + y, _ = scaled_fp8_quant(x, scale) + ref_y = quantize_ref_per_tensor(x, scale) + torch.testing.assert_close(y, ref_y) + torch.testing.assert_close( + dequantize_per_tensor(y, scale, dtype), + dequantize_per_tensor(ref_y, scale, dtype), + ) + + +if is_cuda: + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None: + def quantize_ref_per_token(tensor, inv_scale): + # The reference implementation that fully aligns to + # the kernel being tested. + finfo = torch.finfo(torch.float8_e4m3fn) + scale = inv_scale.reciprocal() + qweight = (tensor.to(torch.float32) * scale).clamp( + min=finfo.min, max=finfo.max + ) + qweight = qweight.to(torch.float8_e4m3fn) + return qweight + + def dequantize_per_token(tensor, inv_scale, dtype): + fake_qweight = tensor.to(dtype) + dq_weight = fake_qweight * inv_scale + return dq_weight + + # Note that we use a shape % 8 = 0, + # because per_token_quant_fp8 is vectorized by 8 elements. 
+        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)
+
+        # Test Per Token Dynamic quantization
+        # scale[i] = max(abs(x[i, :])) / FP8_E4M3_MAX
+        y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
+        ref_y = quantize_ref_per_token(x, scale)
+        torch.testing.assert_close(y, ref_y)
+        torch.testing.assert_close(
+            dequantize_per_token(y, scale, dtype),
+            dequantize_per_token(ref_y, scale, dtype),
+        )
+
+
+if __name__ == "__main__":
+    # Run the specific test function directly
+    pytest.main([__file__])
diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
index 12616ff44..9c3b67768 100644
--- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
@@ -14,7 +14,6 @@ __global__ void per_token_quant_fp8_kernel(
     const int64_t hidden_dim,
     const int64_t num_tokens) {
   const int token_idx = blockIdx.x;
-
   if (token_idx >= num_tokens) return;
 
   const int tid = threadIdx.x;
@@ -25,9 +24,20 @@ __global__ void per_token_quant_fp8_kernel(
 
   float max_value = 0.0f;
 
-  for (int i = tid; i < hidden_dim; i += block_dim) {
-    float val = static_cast<float>(token_input[i]);
-    max_value = fmaxf(max_value, fabsf(val));
+  constexpr uint32_t vec_size = 16 / sizeof(T);
+  using vec_t = flashinfer::vec_t<T, vec_size>;
+  const int32_t num_vec_elems = hidden_dim / vec_size;
+
+  // Find max using vectorized loads
+  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * vec_size);
+
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = static_cast<float>(input_vec[j]);
+      max_value = fmaxf(max_value, fabsf(val));
+    }
   }
 
   max_value = blockReduceMax(max_value);
@@ -41,11 +51,7 @@ __global__ void per_token_quant_fp8_kernel(
 
   const float scale_val = 1.0f / block_max;
 
-  constexpr uint32_t vec_size = 16 / sizeof(T);
-  using vec_t = flashinfer::vec_t<T, vec_size>;
-
-  const int32_t num_vec_elems = hidden_dim / vec_size;
-
+  // Quantize using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;
     input_vec.cast_load(token_input + i * vec_size);
@@ -53,7 +59,7 @@
     FP8_TYPE output_arr[vec_size];
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
 #ifndef USE_ROCM
       output_arr[j] = static_cast<FP8_TYPE>(val);
 #else
@@ -68,18 +74,6 @@
       token_output[i * vec_size + j] = output_arr[j];
     }
   }
-
-  const int32_t remaining_start = num_vec_elems * vec_size;
-  for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
-    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
-#ifndef USE_ROCM
-    token_output[idx] = static_cast<FP8_TYPE>(val);
-#else
-    token_output[idx] = c10::Float8_e4m3fnuz(
-        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
-        c10::Float8_e4m3fnuz::from_bits());
-#endif
-  }
 }
 
 void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
@@ -91,7 +85,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
   const int64_t num_tokens = input_sizes[0];
   const int64_t hidden_dim = input_sizes[1];
 
-  const int block_size = 128;
+  TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
+
+  const int block_size = 256;
 
   const int num_blocks = num_tokens;
   dim3 grid(num_blocks);
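
Usage sketch: the snippet below mirrors the new tests and illustrates how the scaled_fp8_quant wrapper added in python/sglang/srt/custom_op.py is expected to be called. It assumes a CUDA device and an sgl_kernel build that exposes sgl_per_tensor_quant_fp8 and sgl_per_token_quant_fp8; the tensor shape is illustrative only.

# Minimal usage sketch (assumes CUDA and sgl_kernel; shapes are illustrative).
import torch

from sglang.srt.custom_op import scaled_fp8_quant

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

# Dynamic per-tensor quantization: a single scale for the whole tensor.
q_per_tensor, tensor_scale = scaled_fp8_quant(x, scale=None)

# Dynamic per-token quantization: one scale per row, returned with shape (num_tokens, 1).
q_per_token, token_scales = scaled_fp8_quant(x, scale=None, use_per_token_if_dynamic=True)

# Static per-tensor quantization: reuse a precomputed scalar scale.
q_static, _ = scaled_fp8_quant(x, scale=tensor_scale)

print(q_per_token.dtype, token_scales.shape)  # torch.float8_e4m3fn, torch.Size([4, 1])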