Revert "Integration of TurboMind AWQ" (#2866)
This commit is contained in:
@@ -28,7 +28,7 @@ runtime_common = [
|
||||
srt = [
|
||||
"sglang[runtime_common]", "cuda-python",
|
||||
"sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
|
||||
"flashinfer==0.1.6", "turbomind"
|
||||
"flashinfer==0.1.6"
|
||||
]
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
@@ -231,7 +230,7 @@ class ModelConfig:
|
||||
# Parse quantization method from the HF model config, if available.
|
||||
quant_cfg = self._parse_quant_hf_config()
|
||||
|
||||
if quant_cfg is not None and not quantization_in_turbomind(self.quantization):
|
||||
if quant_cfg is not None:
|
||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Detect which checkpoint is it
|
||||
@@ -402,10 +401,3 @@ def is_multimodal_model(model_architectures: List[str]):
|
||||
|
||||
def is_encoder_decoder_model(model_architectures: List[str]):
|
||||
return "MllamaForConditionalGeneration" in model_architectures
|
||||
|
||||
|
||||
def quantization_in_turbomind(quantization: str) -> bool:
|
||||
if quantization in ["awq_turbomind"]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -48,7 +48,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"GPTQLinearMethod",
|
||||
"FBGEMMFp8LinearMethod",
|
||||
"ModelOptFp8LinearMethod",
|
||||
"AWQTurbomindLinearMethod",
|
||||
"IPEXAWQLinearMethod",
|
||||
]
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||
|
||||
from sglang.srt.layers.quantization.awq_turbomind import AWQTurbomindConfig
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||
@@ -38,7 +37,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"awq_turbomind": AWQTurbomindConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import turbomind
|
||||
from torch.nn import Parameter
|
||||
|
||||
turbomind_dir = os.path.split(turbomind.__file__)[0]
|
||||
sys.path.append(os.path.join(turbomind_dir, "lib"))
|
||||
import _turbomind_ext
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.turbomind_utils import (
|
||||
get_u4_slices,
|
||||
is_layer_skipped_awq,
|
||||
pack_u4_row,
|
||||
unpack_awq_gemm,
|
||||
verify_turbomind_supported,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.utils import is_cuda, set_weight_attrs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AWQTurbomindConfig(QuantizationConfig):
|
||||
"""Config class for AWQ Turbomind"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.weight_bits = weight_bits
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
|
||||
verify_turbomind_supported(self.weight_bits, self.group_size)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"AWQTurbomindConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq_turbomind"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQTurbomindConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None
|
||||
)
|
||||
return cls(
|
||||
weight_bits,
|
||||
group_size,
|
||||
zero_point,
|
||||
lm_head_quantized,
|
||||
modules_to_not_convert,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_awq_turbomind_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = user_quant is None or user_quant == "awq_turbomind"
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = f"The model is convertible to {cls.get_name()} during runtime. Using {cls.get_name()} kernel."
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "awq":
|
||||
logger.info(
|
||||
"Detected that the model can run with awq_turbomind"
|
||||
", however you specified quantization=awq explicitly,"
|
||||
" so forcing awq. Use quantization=awq_turbomind for"
|
||||
" faster inference"
|
||||
)
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return AWQTurbomindLinearMethod(self)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_turbomind_compatible(cls, quant_config: Dict[str, Any]):
|
||||
if not is_cuda():
|
||||
return False
|
||||
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
zero_point = quant_config.get("zero_point")
|
||||
|
||||
if quant_method != "awq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if num_bits is None or group_size is None or zero_point is None:
|
||||
return False
|
||||
|
||||
return verify_turbomind_supported(quant_bit=num_bits, group_size=group_size)
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class AWQTurbomindLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ Turbomind.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ Turbomind quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQTurbomindConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
scales = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.num_groups = num_groups
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
qweight_turbomind = unpack_awq_gemm(layer.qweight.data)
|
||||
qzeros_turbomind = unpack_awq_gemm(layer.qzeros.data)
|
||||
scales_turbomind = layer.scales.data
|
||||
|
||||
qweight_turbomind = pack_u4_row(qweight_turbomind)
|
||||
qzeros_turbomind = qzeros_turbomind.to(torch.half)
|
||||
|
||||
device_id = layer.qweight.device.index
|
||||
properties = torch.cuda.get_device_properties(device_id)
|
||||
|
||||
def is_16xx_series(name):
|
||||
import re
|
||||
|
||||
pattern = r"GTX 16\d\d"
|
||||
return bool(re.search(pattern, name))
|
||||
|
||||
simt = is_16xx_series(properties.name)
|
||||
qweight_turbomind = qweight_turbomind.contiguous()
|
||||
scales_turbomind = scales_turbomind.contiguous()
|
||||
qzeros_turbomind = qzeros_turbomind.contiguous()
|
||||
|
||||
self.linear = _turbomind_ext.Linear(
|
||||
layer.input_size_per_partition,
|
||||
layer.output_size_per_partition,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
)
|
||||
|
||||
self.linear.post_init(
|
||||
qweight_turbomind, scales_turbomind, qzeros_turbomind, simt
|
||||
)
|
||||
|
||||
layer.qweight = Parameter(qweight_turbomind, requires_grad=False)
|
||||
layer.scales = Parameter(scales_turbomind, requires_grad=False)
|
||||
layer.qzeros = Parameter(qzeros_turbomind, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (layer.output_size_per_partition,)
|
||||
out = torch.empty(
|
||||
(x.shape[0], layer.output_size_per_partition),
|
||||
dtype=torch.float16,
|
||||
device=x.device,
|
||||
)
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
self.linear.forward(x, out, stream.cuda_stream)
|
||||
out = torch.from_dlpack(out)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
|
||||
return out.view(out_shape)
|
||||
@@ -1,63 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
|
||||
def get_u4_slices(x: torch.Tensor, dtype: torch.dtype) -> List[torch.Tensor]:
|
||||
assert x.dtype == torch.int32
|
||||
xs = []
|
||||
for _ in range(8):
|
||||
xs.append((x & 15).to(dtype))
|
||||
x = x >> 4
|
||||
return xs
|
||||
|
||||
|
||||
def unpack_awq_gemm(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
The int4 weights are packed into int32:
|
||||
bit: 31-28 27-24 23-20 19-16 15-12 11-8 7-4 3-0
|
||||
weight: int4_1 int4_2 int4_3 int4_4 int4_5 int4_6 int4_7 int4_8
|
||||
"""
|
||||
xs = get_u4_slices(x, torch.uint8)
|
||||
order = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
ys = [xs[i] for i in order]
|
||||
return torch.stack(ys, dim=-1).view(*x.shape[:-1], -1)
|
||||
|
||||
|
||||
def pack_u4_row(x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dtype == torch.uint8
|
||||
xs = x.view(*x.shape[:-1], -1, 8).split(1, dim=-1)
|
||||
a = torch.zeros(xs[0].shape, dtype=torch.int32, device=x.device)
|
||||
for t in reversed(xs):
|
||||
a = (a << 4) | t
|
||||
return a.squeeze(dim=-1)
|
||||
|
||||
|
||||
def verify_turbomind_supported(quant_bit: int, group_size: int) -> bool:
|
||||
|
||||
if quant_bit not in [4]:
|
||||
raise NotImplementedError(
|
||||
f"[Tubomind] Only 4-bit is supported for now, but got {quant_bit} bit"
|
||||
)
|
||||
if group_size != 128:
|
||||
raise NotImplementedError(
|
||||
f"[Tubomind] Only group_size 128 is supported for now, "
|
||||
f"but got group_size {group_size}"
|
||||
)
|
||||
|
||||
major, minor = get_device_capability()
|
||||
capability = major * 10 + minor
|
||||
if capability < 70:
|
||||
raise NotImplementedError(
|
||||
f"[Tubomind] Only capability >= 70 is supported for now, but got {capability}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
@@ -375,7 +375,6 @@ class ServerArgs:
|
||||
"marlin",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
"awq_turbomind",
|
||||
"bitsandbytes",
|
||||
"gguf",
|
||||
"modelopt",
|
||||
|
||||
Reference in New Issue
Block a user