Fix deepseek awq v3 (#3450)
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||
from typing import Callable, Dict, Optional, Type
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig,
|
||||
AWQMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig,
|
||||
@@ -73,21 +76,61 @@ def gptq_get_quant_method(self, layer, prefix):
|
||||
|
||||
|
||||
def awq_get_quant_method(self, layer, prefix):
|
||||
from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinLinearMethod,
|
||||
AWQMoEMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return AWQMoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
original_awq_moe_method_apply = AWQMoEMethod.apply
|
||||
|
||||
|
||||
def awq_moe_method_apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return original_awq_moe_method_apply(
|
||||
self,
|
||||
layer,
|
||||
x,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
|
||||
|
||||
def patch_vllm_linear_base_isinstance():
|
||||
import builtins
|
||||
|
||||
@@ -107,8 +150,11 @@ def patch_vllm_linear_base_isinstance():
|
||||
|
||||
def apply_monkey_patches():
|
||||
"""Apply all monkey patches in one place."""
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
||||
|
||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
||||
setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
|
||||
|
||||
|
||||
patch_vllm_linear_base_isinstance()
|
||||
|
||||
Reference in New Issue
Block a user