diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 6ba686621..89e8d23bf 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -44,6 +44,7 @@ else:
     fp8_min = -fp8_max
 
 _enable_jit_deepgemm = False
+_enable_jit_deepgemm_bmm = False
 if _is_cuda:
     import deep_gemm
     from sgl_kernel import (
@@ -53,10 +54,11 @@ if _is_cuda:
     )
 
     sm_version = get_device_sm()
-    if sm_version == 90 and get_bool_env_var(
-        "SGL_ENABLE_JIT_DEEPGEMM", default="false"
-    ):
-        _enable_jit_deepgemm = True
+    if sm_version == 90:
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+            _enable_jit_deepgemm = True
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
+            _enable_jit_deepgemm_bmm = True
 
 
 logger = logging.getLogger(__name__)
@@ -940,6 +942,108 @@ def per_tensor_quant_mla_fp8(
     return x_q, x_s_out
 
 
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+    """
+    This function quantizes input values to float8 values with per-token-group quantization
+    for deep_gemm grouped_gemm_masked and is specialized for the MLA absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
+        -fp8_max,
+        fp8_max,
+        num_tiles_k,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
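For review, a minimal pure-PyTorch sketch of the layout produced by per_tensor_quant_mla_deep_gemm_masked_fp8 (not part of the patch; it assumes group_size divides k and zero-fills the padding rows that the Triton kernel leaves uninitialized):

# Editorial sketch, not part of the patch: reference semantics of the new wrapper.
import torch

def ref_masked_group_quant(x: torch.Tensor, group_size: int = 128, eps: float = 1e-12):
    b, m, k = x.shape
    aligned_m = (m + 255) // 256 * 256  # same block_m padding as the kernel
    g = k // group_size
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    xg = x.float().view(b, m, g, group_size)
    s = xg.abs().amax(dim=-1).clamp_min(eps) / fp8_max            # (b, m, g) per-token-group scales
    q = (xg / s.unsqueeze(-1)).clamp(-fp8_max, fp8_max).view(b, m, k)
    x_q = x.new_zeros((b, aligned_m, k), dtype=torch.float8_e4m3fn)
    x_s = x.new_zeros((b, aligned_m, g), dtype=torch.float32)     # already in the transposed (b, t, g) layout
    x_q[:, :m] = q.to(torch.float8_e4m3fn)
    x_s[:, :m] = s
    masked_m = torch.full((b,), m, dtype=torch.int32, device=x.device)
    return x_q, x_s, masked_m, m, aligned_m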
+ """ + assert x.dim() == 3, "`x` is not a 3d-tensor" + + finfo = torch.finfo(dtype) + fp8_max = finfo.max + if _is_hip: + dtype = torch.float8_e4m3fnuz + fp8_max = 224.0 + + b, m, k = x.shape + aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel + num_tiles_k = k // group_size + assert num_tiles_k * group_size == k, f"k % {group_size} must be zero" + + x_q = x.new_empty((b, aligned_m, k), dtype=dtype) + x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32) + masked_m = x.new_empty((b,), dtype=torch.int32) + + BLOCK_SIZE = triton.next_power_of_2(group_size) + grid = (m, b) + + _per_token_group_quant_mla_deep_gemm_masked_fp8[grid]( + x, + x_q, + x_s, + masked_m, + group_size, + x.stride(0), + x.stride(1), + x_q.stride(0), + x_q.stride(1), + x_s.stride(0), + x_s.stride(1), + eps, + -fp8_max, + fp8_max, + num_tiles_k, + BLOCK_SIZE, + ) + + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m + + def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 47dc5beaf..625df4642 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -57,7 +57,11 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8 +from sglang.srt.layers.quantization.fp8_kernel import ( + _enable_jit_deepgemm_bmm, + per_tensor_quant_mla_deep_gemm_masked_fp8, + per_tensor_quant_mla_fp8, +) from sglang.srt.layers.quantization.fp8_utils import ( block_quant_to_tensor_quant, channel_quant_to_tensor_quant, @@ -82,6 +86,7 @@ _is_hip = is_hip() _is_cuda = is_cuda() if _is_cuda: + from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 else: from vllm._custom_ops import awq_dequantize @@ -530,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_vc = None self.w_scale = None + self.w_scale_k = None + self.w_scale_v = None + self.use_deep_gemm_bmm = False + self.flashinfer_mla_disable_ragged = global_server_args_dict[ "flashinfer_mla_disable_ragged" ] @@ -684,7 +693,24 @@ class DeepseekV2AttentionMLA(nn.Module): ) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - if self.w_kc.dtype == torch.float8_e4m3fnuz: + if self.use_deep_gemm_bmm: + q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = ( + per_tensor_quant_mla_deep_gemm_masked_fp8( + q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn + ) + ) + q_nope_out = q_nope.new_empty( + (self.num_local_heads, aligned_m, self.kv_lora_rank) + ) + m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (q_nope_val, q_nope_scale), + (self.w_kc, self.w_scale_k), + q_nope_out, + masked_m, + expected_m, + ) + q_nope_out = q_nope_out[:, :expected_m, :] + elif self.w_kc.dtype == torch.float8_e4m3fnuz: # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), @@ -716,7 +742,24 @@ class DeepseekV2AttentionMLA(nn.Module): attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - if self.w_vc.dtype == torch.float8_e4m3fnuz: + if self.use_deep_gemm_bmm: + 
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( + per_tensor_quant_mla_deep_gemm_masked_fp8( + attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn + ) + ) + attn_bmm_output = attn_output.new_empty( + (self.num_local_heads, aligned_m, self.v_head_dim) + ) + m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (attn_output_val, attn_output_scale), + (self.w_vc, self.w_scale_v), + attn_bmm_output, + masked_m, + expected_m, + ) + attn_bmm_output = attn_bmm_output[:, :expected_m, :] + elif self.w_vc.dtype == torch.float8_e4m3fnuz: # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), @@ -1439,6 +1482,10 @@ class DeepseekV2ForCausalLM(nn.Module): w = self_attn.kv_b_proj.weight # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. # This may affect the accuracy of fp8 model. + # Fix deepseek v3 blockwise bmm by using deep_gemm + use_deep_gemm_bmm = False + model_dtype = torch.get_default_dtype() + if w.dtype in ( torch.float8_e4m3fn, torch.float8_e4m3fnuz, @@ -1457,10 +1504,20 @@ class DeepseekV2ForCausalLM(nn.Module): weight = w weight_scale = self_attn.kv_b_proj.weight_scale_inv - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - self_attn.w_scale = scale + if ( + _is_cuda + and _enable_jit_deepgemm_bmm + and weight_block_size[0] == 128 + and weight_block_size[1] == 128 + and model_dtype == torch.bfloat16 + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale else: weight = w weight_scale = self_attn.kv_b_proj.weight_scale @@ -1483,18 +1540,31 @@ class DeepseekV2ForCausalLM(nn.Module): w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( torch.bfloat16 ) + w_kc, w_vc = w.unflatten( 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if _is_hip: - self_attn.w_scale *= 2.0 + if not use_deep_gemm_bmm: + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + if _is_hip: + self_attn.w_scale *= 2.0 + else: + num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = self_attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous() + self_attn.w_scale_v = ws_vc.contiguous() + self_attn.w_kc = w_kc.transpose(1, 2).contiguous() + self_attn.w_vc = w_vc.contiguous() + self_attn.use_deep_gemm_bmm = True def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index c7cdd34ca..117acf3a1 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -7,6 +7,7 @@ import torch from 
@@ -212,6 +213,62 @@ class TestPerTensorQuantMlaFP8(CustomTestCase):
                 self._per_tensor_quant_mla_fp8(*params)
 
 
+class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    B = [128]
+    NUM_TOKENS = [7, 83, 2048, 1024 * 16]
+    D = [512, 128]
+    GROUP_SIZE = [128]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+        self, b, num_tokens, d, dtype, group_size, seed
+    ):
+        torch.manual_seed(seed)
+
+        x = torch.rand(b, num_tokens, d, dtype=dtype)
+
+        with torch.inference_mode():
+            ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
+            out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
+                x, group_size
+            )
+            out = out[:, :num_tokens, :]
+            scale = scale[:, :num_tokens, :]
+
+        self.assertTrue(
+            torch.allclose(
+                out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20, atol=1e-2
+            )
+        )
+        self.assertTrue(torch.allclose(scale, ref_scale))
+
+    def test_per_token_group_quant_mla_deep_gemm_masked_fp8(self):
+        for params in itertools.product(
+            self.B,
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.GROUP_SIZE,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                b=params[0],
+                num_tokens=params[1],
+                d=params[2],
+                dtype=params[3],
+                group_size=params[4],
+                seed=params[5],
+            ):
+                self._per_token_group_quant_mla_deep_gemm_masked_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
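The class above checks the Triton kernel against native_per_token_group_quant_fp8, a reference helper defined earlier in test_block_fp8.py and not shown in this diff. For orientation only, a rough sketch of the semantics that helper implements (the actual helper may differ in details such as its eps default and dtype handling):

# Editorial sketch, not part of the patch: approximate reference semantics.
import torch

def ref_per_token_group_quant_fp8(x, group_size, eps=1e-12):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    xg = x.float().unflatten(-1, (-1, group_size))                      # (..., groups, group_size)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / fp8_max
    x_q = (xg / scale).clamp(-fp8_max, fp8_max).flatten(-2).to(torch.float8_e4m3fn)
    return x_q, scale.squeeze(-1)                                       # scales: (..., groups)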
@@ -485,5 +542,115 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
                 self._w8a8_block_fp8_fused_moe(*params)
 
 
+# For test
+def torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_shape, out_dtype):
+    """This function performs bmm with block-wise quantization using native torch."""
+
+    B, N, _ = w.shape
+    _, M, _ = a.shape
+    out = torch.empty((B, M, N), dtype=out_dtype, device=a.device)
+
+    for i in range(B):
+        out[i] = native_w8a8_block_fp8_matmul(
+            a[i], w[i], a_s[i], w_s[i], block_shape, output_dtype=out_dtype
+        )
+
+    return out
+
+
+class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
+    DTYPES = [torch.bfloat16]
+    M = [1, 33, 64, 222, 8192]
+    N = [128, 512]
+    K = [128, 512]
+    BATCH = [128]
+    BLOCK_SIZE = [[128, 128]]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        try:
+            import deep_gemm
+        except ImportError:
+            raise unittest.SkipTest("DeepGEMM is not available")
+        torch.set_default_device("cuda")
+
+    def _w8a8_block_fp8_batched_deep_gemm(self, M, N, K, B, block_size, dtype, seed):
+        torch.manual_seed(seed)
+        factor_for_scale = 1e-2
+        fp8_info = torch.finfo(torch.float8_e4m3fn)
+        fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+        a_fp32 = torch.randn((B, M, K), dtype=torch.float32) / 10
+        a = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        w_fp32 = (torch.rand((B, N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
+        w = w_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        block_n, block_k = block_size[0], block_size[1]
+        n_tiles_w = (N + block_n - 1) // block_n
+        k_tiles_w = (K + block_k - 1) // block_k
+
+        w_s = (
+            torch.rand((B, n_tiles_w, k_tiles_w), dtype=torch.float32)
+            * factor_for_scale
+        )
+        a_s = torch.rand((B, M, k_tiles_w), dtype=torch.float32) * factor_for_scale
+
+        ae = a.new_empty(B, (M + 255) // 256 * 256, K)
+        ae_s = a_s.new_empty(B, (M + 255) // 256 * 256, k_tiles_w)
+        oe = torch.empty((B, (M + 255) // 256 * 256, N), dtype=dtype)
+        ae[:, :M, :] = a
+        ae_s[:, :M, :] = a_s
+
+        masked_m = torch.full((B,), M, dtype=torch.int)
+        expected_m = M
+        lhs = (
+            ae,
+            ae_s,
+        )
+        rhs = (
+            w,
+            w_s,
+        )
+
+        from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
+
+        with torch.inference_mode():
+            ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
+            out = oe[:, :M, :]
+
+        self.assertTrue(
+            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
+            / torch.mean(torch.abs(ref_out.to(torch.float32)))
+            < 0.0001
+        )
+
+    def test_w8a8_block_fp8_batched_deep_gemm(self):
+
+        for params in itertools.product(
+            self.M,
+            self.N,
+            self.K,
+            self.BATCH,
+            self.BLOCK_SIZE,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                M=params[0],
+                N=params[1],
+                K=params[2],
+                B=params[3],
+                block_size=params[4],
+                dtype=params[5],
+                seed=params[6],
+            ):
+                self._w8a8_block_fp8_batched_deep_gemm(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
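Putting the two new pieces together, a minimal end-to-end sketch of the masked grouped BMM path (not part of the patch; it assumes a Hopper GPU with DeepGEMM installed and uses illustrative shapes):

# Editorial sketch, not part of the patch: quantize activations with the new
# wrapper and run DeepGEMM's masked grouped GEMM, mirroring the attention path.
import torch
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_deep_gemm_masked_fp8,
)

torch.set_default_device("cuda")
heads, m, k, n = 16, 64, 128, 512  # illustrative: 16 local heads, 64 tokens

x = torch.randn(heads, m, k, dtype=torch.bfloat16)                  # per-head activations
w = torch.randn(heads, n, k).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
w_s = torch.rand(heads, n // 128, k // 128, dtype=torch.float32)    # 128x128 block scales

x_q, x_s, masked_m, expected_m, aligned_m = per_tensor_quant_mla_deep_gemm_masked_fp8(x)
out = x.new_empty(heads, aligned_m, n, dtype=torch.bfloat16)
m_grouped_gemm_fp8_fp8_bf16_nt_masked((x_q, x_s), (w, w_s), out, masked_m, expected_m)
out = out[:, :expected_m, :]  # drop the rows added for block_m alignment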