[4/n] decouple quantization implementation from vLLM dependency (#9191)

Co-authored-by: AniZpZ <aniz1905@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4.post1",
+    "sgl-kernel==0.3.5",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.4",
+            "0.3.5",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
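Note: the hunk above only bumps the minimum version inside the existing env-gated check. A minimal sketch of the same pattern, assuming importlib.metadata for the installed version rather than SGLang's own assert_pkg_version helper (MIN_SGL_KERNEL and pkg_version_ok are illustrative names, not SGLang APIs):

# Hedged sketch: env-gated minimum-version check, not the real SGLang implementation.
import os
from importlib.metadata import PackageNotFoundError, version

MIN_SGL_KERNEL = "0.3.5"

def pkg_version_ok(pkg: str, minimum: str) -> bool:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        return False
    # Simple numeric tuple compare; a real check would use packaging.version.
    def to_tuple(v: str):
        return tuple(int(p) for p in v.split(".")[:3] if p.isdigit())
    return to_tuple(installed) >= to_tuple(minimum)

if not os.environ.get("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
    assert pkg_version_ok("sgl-kernel", MIN_SGL_KERNEL), (
        "Please reinstall the latest version with "
        "`pip install sgl-kernel --force-reinstall`"
    )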
@@ -55,13 +55,7 @@ if is_mxfp_supported:
     from sglang.srt.layers.quantization.fp4 import MxFp4Config

 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -70,7 +64,6 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -86,6 +79,10 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
@@ -111,19 +108,15 @@ elif is_mxfp_supported and is_hip():
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }

 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
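Note: the registries above are plain name-to-config mappings, and `{**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}` lets the later dict win on duplicate keys; moving awq/gptq entries into the base dict means they now resolve to the vLLM-free configs. A minimal sketch of that lookup, with placeholder classes rather than the real configs:

# Hedged sketch of the registry merge/lookup pattern shown above.
class _BaseGPTQConfig: ...
class _VLLMOnlyGGUFConfig: ...

BASE = {"gptq": _BaseGPTQConfig}
VLLM_ONLY = {"gguf": _VLLMOnlyGGUFConfig}
METHODS = {**BASE, **VLLM_ONLY}  # later dict overrides earlier on key collisions

def get_config_cls(name: str):
    if name not in METHODS:
        raise ValueError(f"Invalid quantization method: {name}")
    return METHODS[name]

print(get_config_cls("gptq"))  # resolves to the base (vLLM-free) config class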
@@ -145,23 +138,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]


-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
-
 original_isinstance = builtins.isinstance
@@ -239,10 +215,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):

 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-
-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
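Note: the deleted lines used the classic setattr monkey patch, attaching a module-level function as a method on the GPTQ config classes; with get_quant_method implemented on the configs themselves (see the GPTQConfig hunk further down), the patch is no longer needed. A toy illustration of what the removed setattr pattern does, with hypothetical classes rather than SGLang's:

# Hedged sketch: setattr-based monkey patching replaces a method at import time.
class QuantConfig:
    def get_quant_method(self, layer, prefix):
        return None  # default behaviour

def patched_get_quant_method(self, layer, prefix):
    # A plain function assigned onto the class becomes a bound method.
    return f"patched({prefix})"

setattr(QuantConfig, "get_quant_method", patched_get_quant_method)
assert QuantConfig().get_quant_method(layer=None, prefix="lm_head") == "patched(lm_head)"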
@@ -35,22 +35,18 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-
-    warnings.warn(
-        f"Using kernels directly from vllm. This might lead to performance degradation or "
-        f"missing functionalities as certain kernels may not be optimized. "
-    )
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, fused_marlin_moe
+    from sgl_kernel import (
+        awq_dequantize,
+        awq_marlin_moe_repack,
+        awq_marlin_repack,
+        fused_marlin_moe,
+    )

 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
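Note: the hunk above drops the optional vllm._custom_ops import (and its performance warning) in favor of importing the AWQ/Marlin kernels from sgl_kernel when CUDA is detected. As a rough sketch of guarding against a missing kernel package, not the exact SGLang code path (the fallback to None is illustrative):

# Hedged sketch: prefer sgl_kernel kernels, degrade gracefully when unavailable.
import warnings

try:
    from sgl_kernel import awq_dequantize  # CUDA-only kernel package, vLLM-free
except ImportError:
    awq_dequantize = None
    warnings.warn("sgl_kernel is not available; AWQ dequantization kernels are disabled.")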
@@ -519,7 +515,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.workspace = marlin_make_workspace(device)

         # Repack weights from AWQ format to marlin format.
-        marlin_qweight = ops.awq_marlin_repack(
+        marlin_qweight = awq_marlin_repack(
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
@@ -687,7 +683,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )

-        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+        marlin_w13_qweight = awq_marlin_moe_repack(
             layer.w13_qweight,
             layer.w13_g_idx_sort_indices,
             size_k=layer.w13_qweight.shape[1],
@@ -696,7 +692,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

-        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+        marlin_w2_qweight = awq_marlin_moe_repack(
             layer.w2_qweight,
             layer.w2_g_idx_sort_indices,
             size_k=layer.w2_qweight.shape[1],
@@ -46,17 +46,12 @@ from sglang.srt.layers.quantization.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 if _is_cuda:
-    from sgl_kernel import fused_marlin_moe
+    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle


 logger = logging.getLogger(__name__)
@@ -86,9 +81,7 @@ def gptq_marlin_moe_repack(
         dtype=b_q_weight.dtype,
     )
     for e in range(num_experts):
-        output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
+        output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
     return output
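Note: the loop above repacks each expert's quantized weight matrix independently, since MoE weights carry a leading expert dimension while the repack kernel handles one 2-D matrix at a time. A shape-only sketch of the same loop, using a dummy repack function and illustrative sizes in place of the sgl_kernel gptq_marlin_repack kernel:

# Hedged sketch: per-expert repacking over a stacked MoE weight tensor.
import torch

def dummy_repack(w: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int):
    return w.clone()  # placeholder: a real kernel rewrites the packed layout

num_experts, size_k, size_n, num_bits = 4, 128, 256, 4
b_q_weight = torch.zeros((num_experts, size_k // 8, size_n), dtype=torch.int32)
perm = torch.zeros((num_experts, size_k), dtype=torch.int32)

output = torch.empty_like(b_q_weight)
for e in range(num_experts):
    output[e] = dummy_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)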
@@ -205,11 +198,12 @@ class GPTQConfig(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-        if isinstance(layer, LinearBase):
-            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-        elif isinstance(layer, FusedMoE):
+        if isinstance(layer, FusedMoE):
             raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-        return None
+        else:
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )


 class GPTQMarlinConfig(QuantizationConfig):
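Note: after this change, GPTQConfig.get_quant_method fails fast with a TypeError for MoE layers and routes every other layer through get_linear_quant_method instead of returning None for unmatched types. A sketch of that dispatch shape, with placeholder layer classes rather than SGLang's:

# Hedged sketch of the dispatch behaviour after the hunk above.
class FusedMoELayer: ...
class LinearLayer: ...

def get_quant_method(layer, prefix: str):
    if isinstance(layer, FusedMoELayer):
        raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
    # Everything else is handed to the linear quant method resolver.
    return f"GPTQLinearMethod for {prefix}"

print(get_quant_method(LinearLayer(), "model.layers.0.mlp.down_proj"))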
@@ -531,7 +525,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.empty(
                 (0,), dtype=torch.int, device=layer.g_idx.device
             )
-        ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
+        gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

     def apply(
         self,
@@ -542,7 +536,7 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])

-        output = ops.gptq_gemm(
+        output = gptq_gemm(
             reshaped_x,
             layer.qweight,
             layer.qzeros,
@@ -727,7 +721,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+            x.data = gptq_marlin_repack(
                 x.data.contiguous(),
                 perm=layer.g_idx_sort_indices,
                 size_k=c.partition_weight_shape[0],
@@ -24,7 +24,7 @@ from sglang.srt.layers.quantization.utils import (
     pack_cols,
     unpack_cols,
 )
-from sglang.srt.utils import get_device_capability
+from sglang.srt.utils import get_device_capability, is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
@@ -34,6 +34,11 @@ try:
 except ImportError:
     ops = None

+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm
+
 logger = logging.getLogger(__name__)

 ScalarType, scalar_types = get_scalar_types()
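Note: with the guarded import above, the Marlin GEMM entry point is the module-level gptq_marlin_gemm bound from sgl_kernel on CUDA, so the call sites below drop the ops. prefix. If both providers had to coexist during a transition, a thin selector like the following could keep call sites unchanged (marlin_gemm is a hypothetical alias, not an SGLang symbol, and the vLLM fallback mirrors the old ops path shown in this diff):

# Hedged sketch: pick a Marlin GEMM provider once at import time.
try:
    from sgl_kernel import gptq_marlin_gemm as marlin_gemm  # preferred, vLLM-free
except ImportError:
    try:
        from vllm import _custom_ops as _ops  # legacy fallback
        marlin_gemm = _ops.gptq_marlin_gemm
    except ImportError:
        marlin_gemm = None  # callers must check before use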
@@ -458,7 +463,7 @@ def apply_gptq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -509,7 +514,7 @@ def apply_awq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,