From bf8d07a6f912e4ad291f38ed6a95e53fef273f40 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 16 Jan 2025 18:00:03 +0800
Subject: [PATCH] feat: patch linear base (#2915)

---
 python/sglang/srt/layers/linear.py            | 42 +++++++++++++++++--
 .../srt/layers/quantization/__init__.py       | 25 +++++++++--
 python/sglang/srt/layers/quantization/fp8.py  |  7 +++-
 .../srt/layers/quantization/modelopt_quant.py |  3 +-
 .../srt/layers/quantization/w8a8_int8.py      |  2 +-
 python/sglang/srt/utils.py                    |  2 +-
 6 files changed, 68 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index ee9386c13..4596f3d78 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -16,9 +16,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
-# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
-from vllm.model_executor.layers.linear import LinearBase
-
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, layer.weight, bias)
 
 
+class LinearBase(torch.nn.Module):
+    """Base linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
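The new sglang-native LinearBase resolves its quantize method once, at construction
time: UnquantizedLinearMethod when no quant_config is given, otherwise whatever
quant_config.get_quant_method(self, prefix) returns. A minimal standalone sketch of
that dispatch pattern follows; the names QuantConfig, Unquantized, and MyLinear are
illustrative stand-ins, not the real sglang/vLLM classes.

    from typing import Optional

    import torch
    import torch.nn.functional as F


    class Unquantized:
        """Stand-in for UnquantizedLinearMethod: plain F.linear."""

        def apply(self, layer, x, bias=None):
            return F.linear(x, layer.weight, bias)


    class QuantConfig:
        """Stand-in for QuantizationConfig."""

        def get_quant_method(self, layer, prefix: str):
            # A real config would return a quantized linear method here.
            return Unquantized()


    class MyLinear(torch.nn.Module):
        """Subclass sketch showing the construction-time dispatch."""

        def __init__(self, in_f, out_f, quant_config=None, prefix=""):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_f, in_f))
            # Same branch the patched LinearBase.__init__ takes:
            if quant_config is None:
                self.quant_method = Unquantized()
            else:
                self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.quant_method.apply(self, x)


    layer = MyLinear(4, 8, quant_config=QuantConfig())
    print(layer(torch.randn(2, 4)).shape)  # torch.Size([2, 8])
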
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 1a39e8006..88e9af695 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -58,12 +58,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
 
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )
 
-    from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
@@ -77,12 +76,12 @@ def fp8_get_quant_method(self, layer, prefix):
 
 
 def gptq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
         GPTQMarlinMoEMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
@@ -93,12 +92,12 @@ def gptq_get_quant_method(self, layer, prefix):
 
 
 def awq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.awq_marlin import (
         AWQMarlinLinearMethod,
         AWQMoEMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
@@ -108,6 +107,23 @@ def awq_get_quant_method(self, layer, prefix):
     return None
 
 
+def patch_vllm_linear_base_isinstance():
+    import builtins
+
+    from vllm.model_executor.layers.linear import LinearBase
+
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+
+    original_isinstance = builtins.isinstance
+
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
@@ -115,6 +131,7 @@ def apply_monkey_patches():
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
+patch_vllm_linear_base_isinstance()
 # Apply patches when module is imported
 apply_monkey_patches()

diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index d16a3b0c2..5ccac960f 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -9,7 +9,6 @@ from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear,
@@ -25,7 +24,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     requantize_with_max_scale,
 )
 
-from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,

diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 5d65899d6..3e5f996ed 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -5,14 +5,13 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
 
-from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,

diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index 0c39393b7..87ba4cfc5 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -54,7 +54,7 @@ class W8A8Int8Config(QuantizationConfig):
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.model_executor.layers.linear import LinearBase
+        from sglang.srt.layers.linear import LinearBase
 
         if isinstance(layer, LinearBase):
             return W8A8Int8LinearMethod(self)

diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index e70e6b425..c521e002f 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -574,13 +574,13 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
 
 
 def monkey_patch_vllm_gguf_config():
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.gguf import (
         GGUFConfig,
         GGUFEmbeddingMethod,
         GGUFLinearMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
     def get_quant_method_with_embedding_replaced(
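The patch_vllm_linear_base_isinstance hook above works by replacing
builtins.isinstance so that exact checks against vLLM's LinearBase are answered
against sglang's class instead; this is what keeps unmodified vLLM quantization
configs (such as the GGUF path above) recognizing sglang layers. A minimal
standalone sketch of the trick follows; VendoredBase and LocalBase are
illustrative stand-ins for the vLLM and sglang classes.

    import builtins


    class VendoredBase:  # plays the role of vLLM's LinearBase
        pass


    class LocalBase:  # plays the role of sglang's LinearBase
        pass


    _original_isinstance = builtins.isinstance


    def _patched_isinstance(obj, classinfo):
        # Redirect exact checks against the vendored base to the local base.
        # Note: only the identity check `classinfo is VendoredBase` matches,
        # so tuple forms like isinstance(x, (VendoredBase, int)) pass through
        # unchanged -- the same behavior as the patch above.
        if classinfo is VendoredBase:
            return _original_isinstance(obj, LocalBase)
        return _original_isinstance(obj, classinfo)


    builtins.isinstance = _patched_isinstance

    assert isinstance(LocalBase(), VendoredBase)  # redirected check passes
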