feat: patch linear base (#2915)
This commit is contained in:
@@ -58,12 +58,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
|
||||
def fp8_get_quant_method(self, layer, prefix):
|
||||
"""Enhanced get_quant_method for FP8 config."""
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
|
||||
|
||||
@@ -77,12 +76,12 @@ def fp8_get_quant_method(self, layer, prefix):
|
||||
|
||||
|
||||
def gptq_get_quant_method(self, layer, prefix):
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
@@ -93,12 +92,12 @@ def gptq_get_quant_method(self, layer, prefix):
|
||||
|
||||
|
||||
def awq_get_quant_method(self, layer, prefix):
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinLinearMethod,
|
||||
AWQMoEMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
@@ -108,6 +107,23 @@ def awq_get_quant_method(self, layer, prefix):
|
||||
return None
|
||||
|
||||
|
||||
def patch_vllm_linear_base_isinstance():
|
||||
import builtins
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||
|
||||
original_isinstance = builtins.isinstance
|
||||
|
||||
def patched_isinstance(obj, classinfo):
|
||||
if classinfo is LinearBase:
|
||||
return original_isinstance(obj, PatchedLinearBase)
|
||||
return original_isinstance(obj, classinfo)
|
||||
|
||||
builtins.isinstance = patched_isinstance
|
||||
|
||||
|
||||
def apply_monkey_patches():
|
||||
"""Apply all monkey patches in one place."""
|
||||
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
||||
@@ -115,6 +131,7 @@ def apply_monkey_patches():
|
||||
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
||||
|
||||
|
||||
patch_vllm_linear_base_isinstance()
|
||||
# Apply patches when module is imported
|
||||
apply_monkey_patches()
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
@@ -25,7 +24,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
|
||||
@@ -5,14 +5,13 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear,
|
||||
cutlass_fp8_supported,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearMethodBase
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
|
||||
@@ -54,7 +54,7 @@ class W8A8Int8Config(QuantizationConfig):
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return W8A8Int8LinearMethod(self)
|
||||
|
||||
Reference in New Issue
Block a user