check user-specified model_max_len with hf derived max_model_len (#1778)

This commit is contained in:
Xiaoyu Zhang
2024-10-25 03:40:36 +08:00
committed by GitHub
parent fc82f5a743
commit 605972195b

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import logging
import os
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Optional from typing import Optional
@@ -20,6 +22,8 @@ from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length from sglang.srt.hf_transformers_utils import get_config, get_context_length
logger = logging.getLogger(__name__)
class AttentionArch(IntEnum): class AttentionArch(IntEnum):
MLA = auto() MLA = auto()
@@ -46,10 +50,29 @@ class ModelConfig:
model_override_args=model_override_args, model_override_args=model_override_args,
) )
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
derived_context_len = get_context_length(self.hf_text_config)
allow_long_context = os.environ.get(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
)
if context_length is not None: if context_length is not None:
self.context_len = context_length if context_length > derived_context_len:
if allow_long_context:
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors."
)
self.context_len = context_length
else:
raise ValueError(
f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
)
else:
self.context_len = context_length
else: else:
self.context_len = get_context_length(self.hf_text_config) self.context_len = derived_context_len
# Unify the config keys for hf_text_config # Unify the config keys for hf_text_config
self.head_dim = getattr( self.head_dim = getattr(