From 8616357a97c5f68eca194dfbeef0ae51943032ef Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Wed, 12 Feb 2025 22:09:52 +0800
Subject: [PATCH] Fix deepseek awq v3 (#3450)

---
 python/sglang/srt/layers/linear.py           | 17 ++++--
 .../srt/layers/moe/fused_moe_triton/layer.py |  2 +
 .../srt/layers/quantization/__init__.py      | 56 +++++++++++++++++--
 python/sglang/srt/models/deepseek_v2.py      |  4 ++
 4 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 64daf79c5..02557d107 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -421,11 +421,18 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(
-            loaded_weight,
-            tp_rank=self.tp_rank,
-            use_presharded_weights=self.use_presharded_weights,
-        )
+
+        from sglang.srt.layers.parameter import _ColumnvLLMParameter
+
+        if isinstance(param, _ColumnvLLMParameter):
+            # FIXME: why would we need this special case?
+            param.load_column_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            param.load_column_parallel_weight(loaded_weight)
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index dc7152da9..e83a32767 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -298,7 +298,9 @@ class FusedMoE(torch.nn.Module):
             layer=self,
             num_experts=num_experts,
             hidden_size=hidden_size,
+            # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
+            intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 1c0092c1a..da6f29312 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -1,10 +1,13 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+from typing import Callable, Dict, Optional, Type
 
-from typing import Dict, Type
-
+import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+from vllm.model_executor.layers.quantization.awq_marlin import (
+    AWQMarlinConfig,
+    AWQMoEMethod,
+)
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
@@ -73,21 +76,61 @@ def gptq_get_quant_method(self, layer, prefix):
 
 
 def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
     from vllm.model_executor.layers.quantization.awq_marlin import (
         AWQMarlinLinearMethod,
         AWQMoEMethod,
     )
 
-    from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 
-    if isinstance(layer, LinearBase):
+    if isinstance(layer, LinearBase) or (
+        isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+    ):
+        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            return UnquantizedLinearMethod()
         return AWQMarlinLinearMethod(self)
     elif isinstance(layer, FusedMoE):
         return AWQMoEMethod(self)
     return None
 
 
+original_awq_moe_method_apply = AWQMoEMethod.apply
+
+
+def awq_moe_method_apply(
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    use_grouped_topk: bool = False,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    return original_awq_moe_method_apply(
+        self,
+        layer,
+        x,
+        router_logits,
+        top_k,
+        renormalize,
+        use_grouped_topk,
+        topk_group,
+        num_expert_group,
+        custom_routing_function,
+        scoring_func,
+        e_score_correction_bias,
+    )
+
+
 def patch_vllm_linear_base_isinstance():
     import builtins
@@ -107,8 +150,11 @@ def patch_vllm_linear_base_isinstance():
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
+    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
+    setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
 
     patch_vllm_linear_base_isinstance()
 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 438441047..2a1c75cc4 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -255,6 +255,8 @@ class DeepseekV2Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            # FIXME: quick fix for skip quantization
+            prefix=f"self_attn.kv_a_proj_with_mqa",
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -455,6 +457,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            # FIXME: quick fix for skip quantization
+            prefix=f"self_attn.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
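
Note: the code below is an illustrative, self-contained sketch and is not part of the patch. It shows
how the dispatch installed by the patched awq_get_quant_method resolves a per-layer quantization
method. All classes are simplified stand-ins for the corresponding vLLM/SGLang types,
is_layer_skipped_awq is an approximation of vLLM's helper, and
modules_to_not_convert=["kv_a_proj_with_mqa"] is an assumed example value that exercises the
skip-quantization prefix added to deepseek_v2.py.

# Illustrative sketch only: stand-in types, not the real vLLM/SGLang classes.
from typing import List, Optional


class LinearBase: ...              # stand-in for sglang.srt.layers.linear.LinearBase
class ParallelLMHead: ...          # stand-in for sglang ParallelLMHead
class FusedMoE: ...                # stand-in for sglang FusedMoE
class UnquantizedLinearMethod: ... # stand-in for the unquantized fallback method


class AWQMarlinLinearMethod:
    def __init__(self, config): self.config = config


class AWQMoEMethod:
    def __init__(self, config): self.config = config


def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]) -> bool:
    # Approximation of vLLM's helper: skip AWQ when the layer prefix matches an excluded module.
    return any(module in prefix for module in modules_to_not_convert)


class AWQMarlinConfigSketch:
    def __init__(self, lm_head_quantized: bool = False,
                 modules_to_not_convert: Optional[List[str]] = None):
        self.lm_head_quantized = lm_head_quantized
        self.modules_to_not_convert = modules_to_not_convert or []

    def get_quant_method(self, layer, prefix: str):
        # Mirrors the dispatch order of the patched awq_get_quant_method above.
        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQMarlinLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return AWQMoEMethod(self)
        return None


if __name__ == "__main__":
    cfg = AWQMarlinConfigSketch(modules_to_not_convert=["kv_a_proj_with_mqa"])
    # The kv_a_proj_with_mqa prefix added in deepseek_v2.py resolves to the unquantized path.
    print(type(cfg.get_quant_method(LinearBase(), "self_attn.kv_a_proj_with_mqa")).__name__)
    # Other linear layers keep the AWQ Marlin kernel; MoE layers get the AWQ MoE method.
    print(type(cfg.get_quant_method(LinearBase(), "mlp.gate_up_proj")).__name__)
    print(type(cfg.get_quant_method(FusedMoE(), "mlp.experts")).__name__)

Running the sketch prints UnquantizedLinearMethod, AWQMarlinLinearMethod, and AWQMoEMethod in turn,
mirroring how kv_a_proj_with_mqa is left unquantized while other linear and MoE layers keep their
AWQ methods.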