feat: update linear deps 1/N (#1305)
This commit is contained in:
76
python/sglang/srt/layers/quantization/__init__.py
Normal file
76
python/sglang/srt/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fp8": Fp8Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
# The order of gptq methods is important for config.py iteration over
|
||||
# override_quantization_method(..)
|
||||
"marlin": MarlinConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"qqq": QQQConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
}
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in QUANTIZATION_METHODS:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return QUANTIZATION_METHODS[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
|
||||
"""
|
||||
def fp8_get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Fp8MoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
||||
"""
|
||||
122
python/sglang/srt/layers/quantization/base_config.py
Normal file
122
python/sglang/srt/layers/quantization/base_config.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
"""Base class for different quantized methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
|
||||
):
|
||||
"""Create weights for a layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
|
||||
This can be used for example, to transpose weights for computation.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
"""
|
||||
Detects if this quantization method can support a given checkpoint
|
||||
format by overriding the user specified quantization method --
|
||||
this method should only be overwritten by subclasses in exceptional
|
||||
circumstances
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(
|
||||
f"Cannot find any of {keys} in the model's " "quantization config."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
|
||||
"""Get a optional value from the model's quantization config."""
|
||||
try:
|
||||
return QuantizationConfig.get_from_keys(config, keys)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@abstractmethod
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
"""Get the quantize method to use for the quantized layer.
|
||||
|
||||
Args:
|
||||
layer: The layer for the quant method.
|
||||
prefix: The full name of the layer in the state dict
|
||||
Returns:
|
||||
The quantize method. None if the given layer doesn't support quant
|
||||
method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
"""Returns the activation function names that should be post-scaled.
|
||||
|
||||
For now, this is only used by AWQ.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user