From b5fb4ef58a6bbe6c105d533b69e8e8bc2bf4fc3c Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Wed, 8 Jan 2025 18:04:30 +0800 Subject: [PATCH] Update modelopt config and fix running issue (#2792) --- python/sglang/srt/layers/quantization/__init__.py | 2 +- python/sglang/srt/layers/{ => quantization}/modelopt_quant.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) rename python/sglang/srt/layers/{ => quantization}/modelopt_quant.py (99%) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index df20a7a4b..35b0c4d94 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -17,12 +17,12 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8 import Fp8Config +from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, diff --git a/python/sglang/srt/layers/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py similarity index 99% rename from python/sglang/srt/layers/modelopt_quant.py rename to python/sglang/srt/layers/quantization/modelopt_quant.py index 2c0887df2..8ce9d20d1 100644 --- a/python/sglang/srt/layers/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -142,6 +142,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): data=torch.full( (len(output_partition_sizes),), torch.finfo(torch.float32).min, + dtype=torch.float32, ), weight_loader=weight_loader, ),