Fix deepseek awq v3 (#3450)
This commit is contained in:
@@ -421,11 +421,18 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
assert loaded_weight.numel() == 1
|
assert loaded_weight.numel() == 1
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
param.load_column_parallel_weight(
|
|
||||||
loaded_weight,
|
from sglang.srt.layers.parameter import _ColumnvLLMParameter
|
||||||
tp_rank=self.tp_rank,
|
|
||||||
use_presharded_weights=self.use_presharded_weights,
|
if isinstance(param, _ColumnvLLMParameter):
|
||||||
)
|
# FIXME: why would we need this special case?
|
||||||
|
param.load_column_parallel_weight(
|
||||||
|
loaded_weight,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
use_presharded_weights=self.use_presharded_weights,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
param.load_column_parallel_weight(loaded_weight)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|||||||
@@ -298,7 +298,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
layer=self,
|
layer=self,
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
|
# FIXME: figure out which intermediate_size to use
|
||||||
intermediate_size=self.intermediate_size_per_partition,
|
intermediate_size=self.intermediate_size_per_partition,
|
||||||
|
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
weight_loader=self.weight_loader,
|
weight_loader=self.weight_loader,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||||
|
from typing import Callable, Dict, Optional, Type
|
||||||
|
|
||||||
from typing import Dict, Type
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||||
|
AWQMarlinConfig,
|
||||||
|
AWQMoEMethod,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
@@ -73,21 +76,61 @@ def gptq_get_quant_method(self, layer, prefix):
|
|||||||
|
|
||||||
|
|
||||||
def awq_get_quant_method(self, layer, prefix):
    """Select the quantization method object for ``layer``.

    Installed onto ``AWQMarlinConfig`` as ``get_quant_method`` by
    ``apply_monkey_patches``, so ``self`` is the quantization config.

    Returns:
        * ``UnquantizedLinearMethod()`` when the layer is linear-like and
          ``is_layer_skipped_awq(prefix, self.modules_to_not_convert)`` is true.
        * ``AWQMarlinLinearMethod(self)`` for every other linear-like layer.
        * ``AWQMoEMethod(self)`` for ``FusedMoE`` layers.
        * ``None`` for anything else.

    A layer counts as linear-like when it is a ``LinearBase``, or when it is
    a ``ParallelLMHead`` and ``self.lm_head_quantized`` is truthy.
    """
    # Imports are kept function-local, as in the original implementation.
    from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )

    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

    linear_like = isinstance(layer, LinearBase) or (
        isinstance(layer, ParallelLMHead) and self.lm_head_quantized
    )
    if linear_like:
        # Layers named in modules_to_not_convert stay unquantized.
        skipped = is_layer_skipped_awq(prefix, self.modules_to_not_convert)
        return UnquantizedLinearMethod() if skipped else AWQMarlinLinearMethod(self)

    if isinstance(layer, FusedMoE):
        return AWQMoEMethod(self)

    return None
|
||||||
|
|
||||||
|
|
||||||
|
# Keep a handle on the upstream implementation before it is replaced, so the
# wrapper below can delegate to it.
original_awq_moe_method_apply = AWQMoEMethod.apply


def awq_moe_method_apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Delegate to the captured ``AWQMoEMethod.apply``, dropping extra kwargs.

    Every named parameter is forwarded positionally, in declaration order,
    to ``original_awq_moe_method_apply``. Anything arriving through
    ``**kwargs`` is accepted but never forwarded.

    NOTE(review): presumably the ``**kwargs`` sink exists so callers can pass
    arguments the upstream ``apply`` does not accept -- confirm against the
    call sites.

    Returns:
        Whatever the captured upstream ``apply`` returns.
    """
    # Order must match the upstream positional signature.
    forwarded = (
        layer,
        x,
        router_logits,
        top_k,
        renormalize,
        use_grouped_topk,
        topk_group,
        num_expert_group,
        custom_routing_function,
        scoring_func,
        e_score_correction_bias,
    )
    return original_awq_moe_method_apply(self, *forwarded)
|
||||||
|
|
||||||
|
|
||||||
def patch_vllm_linear_base_isinstance():
|
def patch_vllm_linear_base_isinstance():
|
||||||
import builtins
|
import builtins
|
||||||
|
|
||||||
@@ -107,8 +150,11 @@ def patch_vllm_linear_base_isinstance():
|
|||||||
|
|
||||||
def apply_monkey_patches():
|
def apply_monkey_patches():
|
||||||
"""Apply all monkey patches in one place."""
|
"""Apply all monkey patches in one place."""
|
||||||
|
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
||||||
|
|
||||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||||
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
||||||
|
setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
|
||||||
|
|
||||||
|
|
||||||
patch_vllm_linear_base_isinstance()
|
patch_vllm_linear_base_isinstance()
|
||||||
|
|||||||
@@ -255,6 +255,8 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
# FIXME: quick fix for skip quantization
|
||||||
|
prefix=f"self_attn.kv_a_proj_with_mqa",
|
||||||
)
|
)
|
||||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||||
self.kv_b_proj = ColumnParallelLinear(
|
self.kv_b_proj = ColumnParallelLinear(
|
||||||
@@ -455,6 +457,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
# FIXME: quick fix for skip quantization
|
||||||
|
prefix=f"self_attn.kv_a_proj_with_mqa",
|
||||||
)
|
)
|
||||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user