@triton.jit
def _per_tensor_quant_mla_fp8_stage1(
    x_ptr,
    x_s_ptr,
    head_size,
    x_stride_h,
    x_stride_s,
    eps,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the per-tensor quantization: every program reduces one
    # (seq, head) row to its absolute maximum and folds the resulting scale
    # candidate into the single slot behind `x_s_ptr`.
    seq_id = tl.program_id(0)
    head_id = tl.program_id(1)
    offset = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE may exceed head_size; lanes past the row end are masked off.
    mask = offset < head_size

    # Only the two outer strides are applied; the elements of a row are read
    # at consecutive addresses.
    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
    # Clamp from below with `eps` so the scale written here is never zero.
    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)

    # All programs race on the same slot; the largest candidate wins.
    tl.atomic_max(x_s_ptr, _absmax / fp8_max)


@triton.jit
def _per_tensor_quant_mla_fp8_stage2(
    x_ptr,
    x_s_ptr,
    x_q_ptr,
    num_seq,
    head_size,
    x_stride_h,
    x_stride_s,
    fp8_min,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 2 of the per-tensor quantization: every program rescales one
    # (seq, head) row with the scale stored behind `x_s_ptr` and stores the
    # clamped result in the output element type.
    seq_id = tl.program_id(0)
    head_id = tl.program_id(1)
    offset = tl.arange(0, BLOCK_SIZE)
    mask = offset < head_size

    x_s = tl.load(x_s_ptr)
    x_s_inv = 1.0 / x_s

    # The input is addressed through its two outer strides, while the output
    # is addressed as a dense (num_head, num_seq, head_size) layout.
    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size

    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
    tl.store(x_q_ptr + offset, x_q, mask=mask)


def per_tensor_quant_mla_fp8(
    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 3-D tensor to float8 using a single tensor-wide scale.

    Specialized for the MLA absorbed case: the two outer dimensions of ``x``
    may be strided (for example the result of ``transpose(0, 1)``), and the
    quantized values are always written to a freshly allocated tensor.

    Args:
        x: Input of shape ``(num_head, num_seq, head_size)``.
        eps: Lower bound applied to the absolute maximum so that the scale
            is never zero.
        dtype: Target float8 dtype. When ``_is_hip`` is set it is replaced
            by ``torch.float8_e4m3fnuz`` with a maximum of 224.0.

    Returns:
        A tuple ``(x_q, x_s)``: ``x_q`` has the shape of ``x`` and the float8
        dtype; ``x_s`` is a float32 tensor of shape ``(1,)`` holding
        ``max(absmax(x), eps) / fp8_max``.
    """
    assert x.dim() == 3, "`x` is not a 3d-tensor"

    # Both kernels apply only stride(0) and stride(1) and read every row at
    # consecutive addresses. A view whose innermost stride is not 1 would be
    # read from the wrong locations, so materialize it first. This is a no-op
    # for inputs that already have a unit innermost stride.
    if x.stride(-1) != 1:
        x = x.contiguous()

    fp8_max = torch.finfo(dtype).max
    if _is_hip:
        dtype = torch.float8_e4m3fnuz
        fp8_max = 224.0

    x_q = x.new_empty(x.size(), dtype=dtype)
    # Must start at zero: stage 1 only ever raises this value via atomic max.
    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

    num_head, num_seq, head_size = x.shape
    BLOCK_SIZE = triton.next_power_of_2(head_size)
    # One program per (seq, head) row; axis 0 is the sequence index.
    grid = (num_seq, num_head)

    _per_tensor_quant_mla_fp8_stage1[grid](
        x,
        x_s,
        head_size,
        x.stride(0),
        x.stride(1),
        eps,
        fp8_max,
        BLOCK_SIZE,
    )
    _per_tensor_quant_mla_fp8_stage2[grid](
        x,
        x_s,
        x_q,
        num_seq,
        head_size,
        x.stride(0),
        x.stride(1),
        -fp8_max,
        fp8_max,
        BLOCK_SIZE,
    )

    return x_q, x_s
= x_dq_block_tiles[j][i] * x_s[j][i] - x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + x_q_tensor, scale = ( + sgl_scaled_fp8_quant(x_dq_block) + if _is_cuda + else input_to_float8(x_dq_block, dtype=x_q_block.dtype) + ) return x_q_tensor, scale @@ -222,7 +226,11 @@ def channel_quant_to_tensor_quant( x_s: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: x_dq_channel = x_q_channel.to(torch.float32) * x_s - x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype) + x_q_tensor, scale = ( + sgl_scaled_fp8_quant(x_dq_channel) + if _is_cuda + else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype) + ) return x_q_tensor, scale diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d581200cf..64a886653 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -53,10 +53,10 @@ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8 from sglang.srt.layers.quantization.fp8_utils import ( block_quant_to_tensor_quant, channel_quant_to_tensor_quant, - input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.quantization.int8_utils import ( @@ -817,8 +817,8 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_kc.to(torch.bfloat16) * self.w_scale, ) elif self.w_kc.dtype == torch.float8_e4m3fn: - q_nope_val, q_nope_scale = input_to_float8( - q_nope.transpose(0, 1), torch.float8_e4m3fn + q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( + q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn ) q_nope_out = bmm_fp8( q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 @@ -848,8 +848,8 @@ class DeepseekV2AttentionMLA(nn.Module): 
self.w_vc.to(torch.bfloat16) * self.w_scale, ) elif self.w_vc.dtype == torch.float8_e4m3fn: - attn_output_val, attn_output_scale = input_to_float8( - attn_output.transpose(0, 1), torch.float8_e4m3fn + attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( + attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn ) attn_bmm_output = bmm_fp8( attn_output_val, @@ -895,8 +895,8 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_kc.to(torch.bfloat16) * self.w_scale, ) elif self.w_kc.dtype == torch.float8_e4m3fn: - q_nope_val, q_nope_scale = input_to_float8( - q_nope.transpose(0, 1), torch.float8_e4m3fn + q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( + q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn ) q_nope_out = bmm_fp8( q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 @@ -991,8 +991,8 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_vc.to(torch.bfloat16) * self.w_scale, ) elif self.w_vc.dtype == torch.float8_e4m3fn: - attn_output_val, attn_output_scale = input_to_float8( - attn_output.transpose(0, 1), torch.float8_e4m3fn + attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( + attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn ) attn_bmm_output = bmm_fp8( attn_output_val, diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index a7c068ac2..c7cdd34ca 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -7,10 +7,12 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.quantization.fp8_kernel import ( + per_tensor_quant_mla_fp8, per_token_group_quant_fp8, static_quant_fp8, w8a8_block_fp8_matmul, ) +from sglang.srt.layers.quantization.fp8_utils import input_to_float8 from sglang.test.test_utils import CustomTestCase _is_cuda = torch.cuda.is_available() and torch.version.cuda @@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase): 
class TestPerTensorQuantMlaFP8(CustomTestCase):
    """Check ``per_tensor_quant_mla_fp8`` against ``input_to_float8``.

    Each case builds a random tensor, keeps the leading ``last_d`` columns of
    its last dimension, swaps the first two dimensions and feeds that view to
    both quantizers. With ``last_d_ext > 0`` the view's rows are padded in
    memory; with ``last_d_ext == 0`` the slice covers the whole last
    dimension.
    """

    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    LAST_D_EXT = [1024, 0]
    LAST_D = [512]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # Skip the whole class on hosts without CUDA; otherwise make CUDA the
        # default device for tensors created by the tests.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _per_tensor_quant_mla_fp8(self, num_tokens, d, last_d_ext, last_d, dtype, seed):
        """Run one parameter combination and compare values and scales."""
        torch.manual_seed(seed)

        full_shape = (num_tokens, d // last_d, last_d + last_d_ext)
        source = torch.rand(full_shape, dtype=dtype)
        # Keep only the first `last_d` columns; the remainder is padding.
        kept = source.split([last_d, last_d_ext], dim=-1)[0]

        with torch.inference_mode():
            expected_q, expected_s = input_to_float8(kept.transpose(0, 1))
            actual_q, actual_s = per_tensor_quant_mla_fp8(kept.transpose(0, 1))

        self.assertTrue(actual_q.is_contiguous())

        values_match = torch.allclose(
            actual_q.to(torch.float32), expected_q.to(torch.float32), rtol=0.50
        )
        self.assertTrue(values_match)

        scales_match = torch.allclose(
            actual_s.to(torch.float32), expected_s.to(torch.float32)
        )
        self.assertTrue(scales_match)

    def test_per_tensor_quant_mla_fp8(self):
        """Sweep the Cartesian product of all class-level parameter lists."""
        labels = ("num_tokens", "d", "last_d_ext", "last_d", "dtype", "seed")
        combos = itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.LAST_D_EXT,
            self.LAST_D,
            self.DTYPES,
            self.SEEDS,
        )
        for combo in combos:
            with self.subTest(**dict(zip(labels, combo))):
                self._per_tensor_quant_mla_fp8(*combo)