check user-specified model_max_len with hf derived max_model_len (#1778)
This commit is contained in:
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -20,6 +22,8 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AttentionArch(IntEnum):
|
class AttentionArch(IntEnum):
|
||||||
MLA = auto()
|
MLA = auto()
|
||||||
@@ -46,10 +50,29 @@ class ModelConfig:
|
|||||||
model_override_args=model_override_args,
|
model_override_args=model_override_args,
|
||||||
)
|
)
|
||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
|
derived_context_len = get_context_length(self.hf_text_config)
|
||||||
|
allow_long_context = os.environ.get(
|
||||||
|
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
|
||||||
|
)
|
||||||
|
|
||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
self.context_len = context_length
|
if context_length > derived_context_len:
|
||||||
|
if allow_long_context:
|
||||||
|
logger.warning(
|
||||||
|
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||||
|
f"This may lead to incorrect model outputs or CUDA errors."
|
||||||
|
)
|
||||||
|
self.context_len = context_length
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||||
|
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
|
||||||
|
f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.context_len = context_length
|
||||||
else:
|
else:
|
||||||
self.context_len = get_context_length(self.hf_text_config)
|
self.context_len = derived_context_len
|
||||||
|
|
||||||
# Unify the config keys for hf_text_config
|
# Unify the config keys for hf_text_config
|
||||||
self.head_dim = getattr(
|
self.head_dim = getattr(
|
||||||
|
|||||||
Reference in New Issue
Block a user