feat: patch linear base (#2915)
This commit is contained in:
@@ -16,9 +16,6 @@ from vllm.distributed import (
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
|
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
PackedColumnParameter,
|
PackedColumnParameter,
|
||||||
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
return F.linear(x, layer.weight, bias)
|
return F.linear(x, layer.weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearBase(torch.nn.Module):
|
||||||
|
"""Base linear layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_size: input dimension of the linear layer.
|
||||||
|
output_size: output dimension of the linear layer.
|
||||||
|
bias: If true, add bias.
|
||||||
|
skip_bias_add: If true, skip adding bias but instead return it.
|
||||||
|
params_dtype: Data type for the parameters.
|
||||||
|
quant_config: Quantization configure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
skip_bias_add: bool = False,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Keep input parameters
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = output_size
|
||||||
|
self.skip_bias_add = skip_bias_add
|
||||||
|
if params_dtype is None:
|
||||||
|
params_dtype = torch.get_default_dtype()
|
||||||
|
self.params_dtype = params_dtype
|
||||||
|
if quant_config is None:
|
||||||
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||||
|
else:
|
||||||
|
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class ReplicatedLinear(LinearBase):
|
class ReplicatedLinear(LinearBase):
|
||||||
"""Replicated linear layer.
|
"""Replicated linear layer.
|
||||||
|
|
||||||
|
|||||||
@@ -58,12 +58,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|||||||
|
|
||||||
def fp8_get_quant_method(self, layer, prefix):
|
def fp8_get_quant_method(self, layer, prefix):
|
||||||
"""Enhanced get_quant_method for FP8 config."""
|
"""Enhanced get_quant_method for FP8 config."""
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
is_layer_skipped,
|
is_layer_skipped,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
|
||||||
|
|
||||||
@@ -77,12 +76,12 @@ def fp8_get_quant_method(self, layer, prefix):
|
|||||||
|
|
||||||
|
|
||||||
def gptq_get_quant_method(self, layer, prefix):
|
def gptq_get_quant_method(self, layer, prefix):
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||||
GPTQMarlinLinearMethod,
|
GPTQMarlinLinearMethod,
|
||||||
GPTQMarlinMoEMethod,
|
GPTQMarlinMoEMethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -93,12 +92,12 @@ def gptq_get_quant_method(self, layer, prefix):
|
|||||||
|
|
||||||
|
|
||||||
def awq_get_quant_method(self, layer, prefix):
|
def awq_get_quant_method(self, layer, prefix):
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||||
AWQMarlinLinearMethod,
|
AWQMarlinLinearMethod,
|
||||||
AWQMoEMethod,
|
AWQMoEMethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -108,6 +107,23 @@ def awq_get_quant_method(self, layer, prefix):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_vllm_linear_base_isinstance():
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||||
|
|
||||||
|
original_isinstance = builtins.isinstance
|
||||||
|
|
||||||
|
def patched_isinstance(obj, classinfo):
|
||||||
|
if classinfo is LinearBase:
|
||||||
|
return original_isinstance(obj, PatchedLinearBase)
|
||||||
|
return original_isinstance(obj, classinfo)
|
||||||
|
|
||||||
|
builtins.isinstance = patched_isinstance
|
||||||
|
|
||||||
|
|
||||||
def apply_monkey_patches():
|
def apply_monkey_patches():
|
||||||
"""Apply all monkey patches in one place."""
|
"""Apply all monkey patches in one place."""
|
||||||
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
||||||
@@ -115,6 +131,7 @@ def apply_monkey_patches():
|
|||||||
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
||||||
|
|
||||||
|
|
||||||
|
patch_vllm_linear_base_isinstance()
|
||||||
# Apply patches when module is imported
|
# Apply patches when module is imported
|
||||||
apply_monkey_patches()
|
apply_monkey_patches()
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from torch.nn import Module
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
apply_fp8_marlin_linear,
|
apply_fp8_marlin_linear,
|
||||||
@@ -25,7 +24,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
from sglang.srt.layers.linear import (
|
||||||
|
LinearBase,
|
||||||
|
LinearMethodBase,
|
||||||
|
UnquantizedLinearMethod,
|
||||||
|
)
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearMethodBase
|
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
return W8A8Int8LinearMethod(self)
|
return W8A8Int8LinearMethod(self)
|
||||||
|
|||||||
@@ -574,13 +574,13 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
|
|||||||
|
|
||||||
|
|
||||||
def monkey_patch_vllm_gguf_config():
|
def monkey_patch_vllm_gguf_config():
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.gguf import (
|
from vllm.model_executor.layers.quantization.gguf import (
|
||||||
GGUFConfig,
|
GGUFConfig,
|
||||||
GGUFEmbeddingMethod,
|
GGUFEmbeddingMethod,
|
||||||
GGUFLinearMethod,
|
GGUFLinearMethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
|
||||||
def get_quant_method_with_embedding_replaced(
|
def get_quant_method_with_embedding_replaced(
|
||||||
|
|||||||
Reference in New Issue
Block a user