From a5317b2fd3dd0bb7667bd9b6c646da8d2301a23b Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Sat, 28 Jun 2025 10:04:29 +0800 Subject: [PATCH] [CPU] add optimizations for INT8 and FP8 DeepSeek (#6769) Co-authored-by: Zheng, Beilei --- .../srt/layers/moe/fused_moe_triton/layer.py | 2 +- python/sglang/srt/layers/quantization/fp8.py | 43 ++++++++++ .../srt/layers/quantization/moe_wna16.py | 2 +- .../srt/layers/quantization/w8a8_int8.py | 52 +++++++++++- python/sglang/srt/models/deepseek_v2.py | 83 +++++++++++++++++++ 5 files changed, 179 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index fd1898891..723601737 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -291,7 +291,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): torch.float ), # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32 topk_ids, - True, # inplace + False, # inplace # See [Note] inplace should be False in fused_experts. False, # use_int8_w8a8 False, # use_fp8_w8a16 None, # w1_scale diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index bc325aa2c..358c152d3 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -64,6 +64,7 @@ from sglang.srt.layers.quantization.utils import ( ) from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.utils import ( + _process_weight_after_loading, cpu_has_amx_support, get_bool_env_var, is_cpu, @@ -330,6 +331,12 @@ class Fp8LinearMethod(LinearMethodBase): ) layer.input_scale = None + elif _is_cpu: + assert ( + _is_cpu_amx_available + ), "Fp8LinearMethod on CPU requires that CPU has AMX support" + _process_weight_after_loading(layer, ["weight"]) + return else: weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data layer.weight = torch.nn.Parameter(weight, requires_grad=False) @@ -426,6 +433,17 @@ class Fp8LinearMethod(LinearMethodBase): ) if self.block_quant: + if getattr(layer, "use_intel_amx_backend", False): + return torch.ops.sgl_kernel.fp8_scaled_mm_cpu( + x, + layer.weight, + layer.weight_scale_inv, + self.quant_config.weight_block_size, + bias, + x.dtype, + True, # is_vnni + ) + return self.w8a8_block_fp8_linear( input=x, weight=layer.weight, @@ -746,6 +764,13 @@ class Fp8MoEMethod: layer.w2_weight.data = shuffle_weight( layer.w2_weight.contiguous(), (16, 16) ) + + if _is_cpu: + assert ( + _is_cpu_amx_available + ), "Fp8MoEMethod on CPU requires that CPU has AMX support" + _process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + return # If checkpoint is fp16 or bfloat16, quantize in place. @@ -971,6 +996,24 @@ class Fp8MoEMethod: routed_scaling_factor=routed_scaling_factor, ) + if getattr(layer, "use_intel_amx_backend", False): + return torch.ops.sgl_kernel.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + False, # inplace See [Note] inplace should be False in fused_experts. + False, # use_int8_w8a8 + True, # use_fp8_w8a16 + layer.w13_weight_scale_inv, # w1_scale + layer.w2_weight_scale_inv, # w2_scale + self.quant_config.weight_block_size, # block_size + None, # a1_scale + None, # a2_scale + True, # is_vnni + ) + if _is_hip: ret = self.maybe_apply_hip_fused_experts( layer, diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index 4be00f8a3..4f3bc716e 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig): capability_tuple = get_device_capability() device_capability = ( -1 - if capability_tuple is None + if all(capability is None for capability in capability_tuple) else capability_tuple[0] * 10 + capability_tuple[1] ) # Avoid circular import diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index a973403ca..4e1d90a0e 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -11,9 +11,17 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 -from sglang.srt.utils import is_cuda, set_weight_attrs +from sglang.srt.utils import ( + _process_weight_after_loading, + cpu_has_amx_support, + is_cpu, + is_cuda, + set_weight_attrs, +) _is_cuda = is_cuda() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import int8_scaled_mm @@ -72,6 +80,13 @@ class W8A8Int8LinearMethod(LinearMethodBase): self.quantization_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if _is_cpu: + assert ( + _is_cpu_amx_available + ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support" + _process_weight_after_loading(layer, ["weight"]) + return + layer.weight = Parameter(layer.weight.t(), requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) @@ -112,6 +127,16 @@ class W8A8Int8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ): + if getattr(layer, "use_intel_amx_backend", False): + return torch.ops.sgl_kernel.int8_scaled_mm_with_quant( + x, + layer.weight, + layer.weight_scale, + bias, + x.dtype, + True, # is_vnni + ) + x_q, x_scale = per_token_quant_int8(x) return int8_scaled_mm( @@ -206,6 +231,13 @@ class W8A8Int8MoEMethod: layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if _is_cpu: + assert ( + _is_cpu_amx_available + ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support" + _process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + return + layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) layer.w13_weight_scale = Parameter( @@ -252,6 +284,24 @@ class W8A8Int8MoEMethod: routed_scaling_factor=routed_scaling_factor, ) + if getattr(layer, "use_intel_amx_backend", False): + return torch.ops.sgl_kernel.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + False, # inplace See [Note] inplace should be False in fused_experts. + True, # use_int8_w8a8 + False, # use_fp8_w8a16 + layer.w13_weight_scale, # w1_scale + layer.w2_weight_scale, # w2_scale + None, # block_size + layer.w13_input_scale, # a1_scale + layer.w2_input_scale, # a2_scale + True, # is_vnni + ) + return fused_experts( x, layer.w13_weight, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c6617858e..79c7066df 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -300,6 +300,9 @@ class DeepseekV2MoE(nn.Module): ), ) + self.shared_experts_is_int8 = False + self.shared_experts_is_fp8 = False + self.shared_experts_weight_block_size = None if config.n_shared_experts is not None and self.num_fused_shared_experts == 0: intermediate_size = config.moe_intermediate_size * config.n_shared_experts # disable tp for shared experts when enable deepep moe @@ -316,6 +319,20 @@ class DeepseekV2MoE(nn.Module): else {} ), ) + self.shared_experts_is_int8 = ( + self.shared_experts.gate_up_proj.weight.dtype == torch.int8 + ) + self.shared_experts_is_fp8 = ( + self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn + ) + if self.shared_experts_is_fp8: + assert ( + self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size + == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size + ) + self.shared_experts_weight_block_size = ( + self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size + ) self.top_k = config.num_experts_per_tok @@ -394,6 +411,11 @@ class DeepseekV2MoE(nn.Module): return final_hidden_states def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: + if hasattr(self, "shared_experts") and getattr( + self.shared_experts.gate_up_proj, "use_intel_amx_backend", False + ): + return self.forward_cpu(hidden_states) + shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) @@ -409,6 +431,59 @@ class DeepseekV2MoE(nn.Module): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states + def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + fused_experts_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + assert getattr( + self.shared_experts.gate_up_proj, "use_intel_amx_backend", False + ) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False) + # [Note] inplace should be False in fused_experts. + # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts + # While hidden_states is still needed in shared_expert. + final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu( + hidden_states, + self.shared_experts.gate_up_proj.weight, + self.shared_experts.down_proj.weight, + fused_experts_out, + self.routed_scaling_factor, + True, # inplace + self.shared_experts_is_int8, # use_int8_w8a8 + self.shared_experts_is_fp8, # use_fp8_w8a16 + ( + self.shared_experts.gate_up_proj.weight_scale + if self.shared_experts_is_int8 + else ( + self.shared_experts.gate_up_proj.weight_scale_inv + if self.shared_experts_is_fp8 + else None + ) + ), # w1_scale + ( + self.shared_experts.down_proj.weight_scale + if self.shared_experts_is_int8 + else ( + self.shared_experts.down_proj.weight_scale_inv + if self.shared_experts_is_fp8 + else None + ) + ), # w2_scale + ( + self.shared_experts_weight_block_size + if self.shared_experts_is_fp8 + else None + ), # block_size + None, # a1_scale + None, # a2_scale + True, # is_vnni + ) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states + def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: @@ -2107,6 +2182,14 @@ class DeepseekV2ForCausalLM(nn.Module): ) if _is_hip: self_attn.w_scale *= 2.0 + # TODO: remove this after adding FP8 support in bmm cpu kernel + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + self_attn.w_kc = ( + self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale + ) + self_attn.w_vc = ( + self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale + ) else: num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1] num_tiles_n = self_attn.v_head_dim // weight_block_size[0]