def get_context_length(config):
    """Get the context length of a model from a Hugging Face model config.

    For a multimodal LLM the length-related fields are read from the nested
    ``text_config``; for any other model they are read from ``config``
    itself.

    Args:
        config: Model config object. It may expose a ``text_config``
            attribute holding the text sub-config.

    Returns:
        The value of the first attribute named in ``CONTEXT_LENGTH_KEYS``
        that is set (not None), multiplied by the RoPE scaling factor and
        converted to ``int``. Returns 2048 when none of those attributes
        is set.
    """
    # A ``text_config`` attribute that exists but is None must not shadow
    # the top-level config; otherwise every lookup below would come back
    # empty and the 2048 fallback would be returned silently.
    text_config = getattr(config, "text_config", None)
    if text_config is None:
        text_config = config

    rope_scaling_factor = 1
    rope_scaling = getattr(text_config, "rope_scaling", None)
    if rope_scaling:
        # The factor is deliberately not applied when
        # "original_max_position_embeddings" is present or the rope type is
        # "llama3" (presumably the configured length is already the
        # extended one in those cases -- TODO confirm).
        factor_ignored = (
            "original_max_position_embeddings" in rope_scaling
            or rope_scaling.get("rope_type", None) == "llama3"
        )
        if not factor_ignored:
            factor = rope_scaling.get("factor", 1)
            # An explicit ``"factor": None`` entry means "no scaling".
            if factor is not None:
                rope_scaling_factor = factor

    for key in CONTEXT_LENGTH_KEYS:
        val = getattr(text_config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048