sglang quant module remove vllm dependency (#4507)

Author: Xiaoyu Zhang
Date: 2025-03-18 06:51:59 +08:00
Committed by: GitHub
Parent: f81a27f65e
Commit: 9b81f9bd34
8 changed files with 907 additions and 238 deletions


@@ -3,11 +3,21 @@ from fractions import Fraction
 from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-from vllm.scalar_type import scalar_types
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.utils import scalar_types
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
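
The try/except above is the core of the change: vllm becomes an optional dependency probed once at import time, and every code path that still needs its kernels checks the flag first. The same guard is repeated in each get_quant_method below; a minimal sketch of how it could be factored into a single helper (hypothetical, not part of this commit):

try:
    import vllm  # optional dependency, only needed for the vllm-backed kernels

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False


def require_vllm(feature: str) -> None:
    # Hypothetical helper: raise a uniform error when a quant method still
    # relies on vllm kernels but vllm is not installed.
    if not VLLM_AVAILABLE:
        raise ImportError(
            f"vllm is not installed, but {feature} still relies on vllm kernels"
        )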
@@ -110,6 +120,9 @@ class GPTQConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["GPTQLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
         from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
         from sglang.srt.layers.quantization import get_linear_quant_method
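
On the caller side, the guard surfaces as an ImportError when a GPTQ checkpoint is loaded without vllm, instead of a deep traceback from inside vllm. A hedged usage sketch, assuming the file is sglang/srt/layers/quantization/gptq.py and that GPTQConfig keeps the usual QuantizationConfig.from_config(dict) constructor:

import torch

from sglang.srt.layers.quantization.gptq import GPTQConfig  # assumed module path

# Keys follow a typical Hugging Face GPTQ quantization_config; values are illustrative.
quant_cfg = GPTQConfig.from_config(
    {"bits": 4, "group_size": 128, "desc_act": False, "lm_head": False}
)
layer = torch.nn.Linear(4096, 4096)  # stand-in for an sglang LinearBase layer
try:
    quant_method = quant_cfg.get_quant_method(layer, prefix="model.layers.0.mlp")
except ImportError:
    # Raised by the new guard when vllm is not installed.
    quant_method = None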
@@ -263,6 +276,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
         from vllm.model_executor.layers.quantization.gptq_marlin import (
             GPTQMarlinLinearMethod,
             GPTQMarlinMoEMethod,
@@ -285,6 +301,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        if not VLLM_AVAILABLE:
+            return False
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
@@ -294,9 +313,8 @@ class GPTQMarlinConfig(QuantizationConfig):
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
             check_marlin_supported,
         )
-        from vllm.platforms import current_platform
 
-        if not current_platform.is_cuda():
+        if not _is_cuda:
             return False
 
         if quant_method != "gptq":
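
The platform probe now reads the module-level _is_cuda flag built from sglang's own is_cuda() helper rather than vllm's current_platform, so is_gptq_marlin_compatible can short-circuit without importing vllm at all. A minimal sketch of a loader choosing between the Marlin and plain GPTQ paths with this check; the import path and the checkpoint dict are illustrative assumptions:

from sglang.srt.layers.quantization.gptq import (  # assumed module path
    GPTQConfig,
    GPTQMarlinConfig,
)

# Keys mirror a typical Hugging Face GPTQ quantization_config.
hf_quant_config = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "desc_act": False,
}

# False when vllm is missing, when the host is not CUDA (_is_cuda flag), or
# when the bits/group_size combination is not supported by the Marlin kernels.
if GPTQMarlinConfig.is_gptq_marlin_compatible(hf_quant_config):
    config_cls = GPTQMarlinConfig
else:
    config_cls = GPTQConfig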
@@ -407,6 +425,9 @@ class MarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["MarlinLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
         from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
 
         if isinstance(layer, LinearBase) or (