Fix quantization and nightly tests (#4258)

Lianmin Zheng
2025-03-10 03:06:21 -07:00
committed by GitHub
parent 1a5023e05d
commit 00d25a7f5e
7 changed files with 142 additions and 70 deletions

python/sglang/srt/layers/quantization/__init__.py

@@ -1,4 +1,6 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+import builtins
+import inspect
 import re
 from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
@@ -6,10 +8,7 @@ from typing import Callable, Dict, Optional, Type, Union
 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import (
-    AWQMarlinConfig,
-    AWQMoEMethod,
-)
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
@@ -180,96 +179,117 @@ def gptq_get_quant_method(self, layer, prefix):
     return None
 
 
 def awq_get_quant_method(self, layer, prefix):
     from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
     from vllm.model_executor.layers.quantization.awq_marlin import (
         AWQMarlinLinearMethod,
         AWQMoEMethod,
     )
 
     from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 
     if isinstance(layer, LinearBase) or (
         isinstance(layer, ParallelLMHead) and self.lm_head_quantized
     ):
         if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
             return UnquantizedLinearMethod()
         return AWQMarlinLinearMethod(self)
     elif isinstance(layer, FusedMoE):
         return AWQMoEMethod(self)
     return None
 
 
-original_awq_moe_method_apply = AWQMoEMethod.apply
-
-
-def awq_moe_method_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool = False,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    **kwargs,
-):
-    return original_awq_moe_method_apply(
-        self,
-        layer,
-        x,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        custom_routing_function,
-        scoring_func,
-        e_score_correction_bias,
-    )
-
-
-def patch_vllm_linear_base_isinstance():
-    import builtins
-
+original_isinstance = builtins.isinstance
+
+
+def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
+    """
+    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
+    can recognize sglang layers
+    """
+    if reverse:
+        builtins.isinstance = original_isinstance
+        return
+
     from vllm.model_executor.layers.fused_moe import FusedMoE
     from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.vocab_parallel_embedding import (
         VocabParallelEmbedding,
     )
 
     from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
     from sglang.srt.layers.vocab_parallel_embedding import (
         VocabParallelEmbedding as PatchedVocabParallelEmbedding,
     )
 
-    original_isinstance = builtins.isinstance
-
     def patched_isinstance(obj, classinfo):
         if classinfo is LinearBase:
             return original_isinstance(obj, PatchedLinearBase)
         if classinfo is FusedMoE:
             return original_isinstance(obj, PatchedFusedMoE)
         if classinfo is VocabParallelEmbedding:
             return original_isinstance(obj, PatchedVocabParallelEmbedding)
         return original_isinstance(obj, classinfo)
 
     builtins.isinstance = patched_isinstance
 
 
-def apply_monkey_patches():
+def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
+    """
+    Monkey patch the apply function of vllm's FusedMoEMethodBase.
+    Convert sglang arguments to vllm arguments.
+    """
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)
+
+
+def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
+    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod
+
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
-    setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
+
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
-patch_vllm_linear_base_isinstance()
 # Apply patches when module is imported
-apply_monkey_patches()
+monkey_patch_quant_configs()
 
 __all__ = [
     "QuantizationConfig",
     "get_quantization_config",
     "QUANTIZATION_METHODS",
 ]
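
The reverse flag added to monkey_patch_isinstance_for_vllm_base_layer suggests the patch is meant to be applied only around the vllm calls that need it and then undone, instead of rebinding builtins.isinstance for the whole process at import time as the removed patch_vllm_linear_base_isinstance did. A minimal, self-contained sketch of that mechanism, using stand-in classes (VllmLinearBase and SglangLinearBase here are illustrative, not the real vllm/sglang layer types):

import builtins


class VllmLinearBase:  # stand-in for vllm.model_executor.layers.linear.LinearBase
    pass


class SglangLinearBase:  # stand-in for sglang.srt.layers.linear.LinearBase
    pass


original_isinstance = builtins.isinstance


def monkey_patch_isinstance(reverse: bool = False):
    """Redirect isinstance checks against the vllm class to the sglang class."""
    if reverse:
        # Restore the real builtin once the vllm call has returned.
        builtins.isinstance = original_isinstance
        return

    def patched_isinstance(obj, classinfo):
        if classinfo is VllmLinearBase:
            return original_isinstance(obj, SglangLinearBase)
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance


layer = SglangLinearBase()
monkey_patch_isinstance()
assert isinstance(layer, VllmLinearBase)   # sglang layer now passes vllm's check
monkey_patch_isinstance(reverse=True)      # undo the patch afterwards
assert not isinstance(layer, VllmLinearBase)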
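
monkey_patch_moe_apply probes the installed vllm's apply signature with inspect.signature once at patch time, so the wrapper forwards e_score_correction_bias only when the installed version accepts it and otherwise raises the upgrade hint. A self-contained sketch of this version-probing pattern, with hypothetical stand-in functions (old_apply, new_apply_v2) in place of the real vllm methods:

import inspect
from typing import Optional


def old_apply(x: float, top_k: int) -> float:  # older API: no bias parameter
    return x * top_k


def new_apply_v2(
    x: float, top_k: int, e_score_correction_bias: Optional[float] = None
) -> float:  # newer API: accepts the bias
    return x * top_k + (e_score_correction_bias or 0.0)


def adapt(original_apply):
    # Inspect the target's parameters once, at patch time, not on every call.
    params = inspect.signature(original_apply).parameters
    has_bias = "e_score_correction_bias" in params

    def wrapper(x: float, top_k: int, correction_bias: Optional[float] = None):
        kwargs = {"x": x, "top_k": top_k}
        if correction_bias is not None:
            if not has_bias:
                raise ValueError("installed version does not accept a correction bias")
            # Translate the caller-side name to the target's keyword.
            kwargs["e_score_correction_bias"] = correction_bias
        return original_apply(**kwargs)

    return wrapper


assert adapt(new_apply_v2)(2.0, 3, correction_bias=1.0) == 7.0
assert adapt(old_apply)(2.0, 3) == 6.0  # old API still works without the bias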