From d9b3b0188338c6a1411c2995db5e8da7f56f6e4d Mon Sep 17 00:00:00 2001 From: Enrique Shockwave <33002121+qeternity@users.noreply.github.com> Date: Wed, 13 Mar 2024 02:10:12 +0000 Subject: [PATCH] enable marlin kernels (#286) --- python/sglang/srt/managers/router/model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 6ea1ac9d7..b63cc6b9d 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -13,12 +13,13 @@ from sglang.srt.utils import is_multimodal_model from sglang.utils import get_available_gpu_memory from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.model_loader import _set_default_torch_dtype from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel import sglang -QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig} +QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig} logger = logging.getLogger("model_runner")