From 146f6134051a23cda360a9de2a1abc1f447b9787 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Sat, 2 Nov 2024 00:04:50 -0700 Subject: [PATCH] Fix incorrect context length for llama3.2-11b (#1873) --- python/sglang/srt/hf_transformers_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index c3abf11e4..7a8751043 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -88,19 +88,23 @@ CONTEXT_LENGTH_KEYS = [ def get_context_length(config): - """Get the context length of a model from a huggingface model configs.""" - rope_scaling = getattr(config, "rope_scaling", None) + """Get the context length of a model from a huggingface model configs. + And here the config should be text_config part if the model is a multimodal + LLM. + """ + text_config = getattr(config, "text_config", config) + rope_scaling = getattr(text_config, "rope_scaling", None) if rope_scaling: - rope_scaling_factor = config.rope_scaling.get("factor", 1) + rope_scaling_factor = rope_scaling.get("factor", 1) if "original_max_position_embeddings" in rope_scaling: rope_scaling_factor = 1 - if config.rope_scaling.get("rope_type", None) == "llama3": + if rope_scaling.get("rope_type", None) == "llama3": rope_scaling_factor = 1 else: rope_scaling_factor = 1 for key in CONTEXT_LENGTH_KEYS: - val = getattr(config, key, None) + val = getattr(text_config, key, None) if val is not None: return int(rope_scaling_factor * val) return 2048