From 605972195bafcd3ffd7a3489dbed4e1d2d0d51dd Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 25 Oct 2024 03:40:36 +0800 Subject: [PATCH] check user-specified model_max_len with hf derived max_model_len (#1778) --- python/sglang/srt/configs/model_config.py | 27 +++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a3c59e8d8..a74d240b4 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +import logging +import os from enum import IntEnum, auto from typing import Optional @@ -20,6 +22,8 @@ from transformers import PretrainedConfig from sglang.srt.hf_transformers_utils import get_config, get_context_length +logger = logging.getLogger(__name__) + class AttentionArch(IntEnum): MLA = auto() @@ -46,10 +50,29 @@ class ModelConfig: model_override_args=model_override_args, ) self.hf_text_config = get_hf_text_config(self.hf_config) + derived_context_len = get_context_length(self.hf_text_config) + allow_long_context = os.environ.get( + "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None + ) + if context_length is not None: - self.context_len = context_length + if context_length > derived_context_len: + if allow_long_context: + logger.warning( + f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " + f"This may lead to incorrect model outputs or CUDA errors." + ) + self.context_len = context_length + else: + raise ValueError( + f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " + f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. " + f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" + ) + else: + self.context_len = context_length else: - self.context_len = get_context_length(self.hf_text_config) + self.context_len = derived_context_len # Unify the config keys for hf_text_config self.head_dim = getattr(