[Hotfix] Fix fp8 w8a8 CI test failure (#4531)

This commit is contained in:
Xiaoyu Zhang
2025-03-18 14:17:04 +08:00
committed by GitHub
parent d373a48c98
commit dd865befde
5 changed files with 110 additions and 417 deletions

View File

@@ -6,7 +6,6 @@ import torch
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.utils import scalar_types
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import is_cuda
@@ -133,11 +132,16 @@ class GPTQConfig(QuantizationConfig):
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
if VLLM_AVAILABLE:
from vllm.scalar_type import scalar_types
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
else:
raise ImportError("vllm is not installed")
def __init__(
self,