Remove the vllm dependency from the sglang quant module (#4507)
This commit is contained in:
@@ -3,11 +3,21 @@ from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.utils import scalar_types
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
try:
|
||||
import vllm
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -110,6 +120,9 @@ class GPTQConfig(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["GPTQLinearMethod"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
@@ -263,6 +276,9 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
@@ -285,6 +301,9 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
if not VLLM_AVAILABLE:
|
||||
return False
|
||||
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
@@ -294,9 +313,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
if not _is_cuda:
|
||||
return False
|
||||
|
||||
if quant_method != "gptq":
|
||||
@@ -407,6 +425,9 @@ class MarlinConfig(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["MarlinLinearMethod"]:
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
|
||||
if isinstance(layer, LinearBase) or (
|
||||
|
||||
Reference in New Issue
Block a user