diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 363289f73..d5f120418 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -19,7 +19,11 @@
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}
 
 logger = logging.getLogger("model_runner")
 
@@ -300,30 +304,31 @@ class ModelRunner:
 
         # Load weights
         linear_method = None
+
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+            linear_method = quant_config.get_linear_method()
+
         with _set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (
-                        hf_quant_method == "gptq"
-                        and "is_marlin_format" in hf_quant_config
-                        and hf_quant_config["is_marlin_format"]
-                    ):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
                     config=self.model_config.hf_config, linear_method=linear_method
                 )
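
For reference, here is a minimal standalone sketch of the marlin-format detection logic the patch introduces, pulled out of `ModelRunner` so it can be tested in isolation. The `resolve_quant_method` helper and the sample config dicts are hypothetical illustrations, not part of the patch; they only mirror the `checkpoint_format` / `is_marlin_format` compat handling shown above.

```python
def resolve_quant_method(quant_cfg: dict) -> str:
    """Normalize quant_method, remapping GPTQ checkpoints serialized in marlin format."""
    quant_method = quant_cfg.get("quant_method", "").lower()
    # autogptq >= 0.8.0 writes checkpoint_format: str,
    # autogptq <= 0.7.1 writes is_marlin_format: bool.
    is_format_marlin = (
        quant_cfg.get("checkpoint_format") == "marlin"
        or quant_cfg.get("is_marlin_format", False)
    )
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method

# Old-style config (autogptq <= 0.7.1): boolean flag.
assert resolve_quant_method({"quant_method": "gptq", "is_marlin_format": True}) == "marlin"
# New-style config (autogptq >= 0.8.0): string checkpoint_format.
assert resolve_quant_method({"quant_method": "gptq", "checkpoint_format": "marlin"}) == "marlin"
# Plain GPTQ checkpoint keeps the gptq path.
assert resolve_quant_method({"quant_method": "gptq"}) == "gptq"
```

Note that the patch also hoists this detection out of the `torch.float16` / `cuda` device context managers, since choosing a quantization config has no dependence on the default dtype or device.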