Fix incorrect context length for llama3.2-11b (#1873)

This commit is contained in:
Ran Chen
2024-11-02 00:04:50 -07:00
committed by GitHub
parent 660ecb731f
commit 146f613405

View File

@@ -88,19 +88,23 @@ CONTEXT_LENGTH_KEYS = [
def get_context_length(config):
"""Get the context length of a model from a huggingface model configs."""
rope_scaling = getattr(config, "rope_scaling", None)
"""Get the context length of a model from a huggingface model configs.
And here the config should be text_config part if the model is a multimodal
LLM.
"""
text_config = getattr(config, "text_config", config)
rope_scaling = getattr(text_config, "rope_scaling", None)
if rope_scaling:
rope_scaling_factor = config.rope_scaling.get("factor", 1)
rope_scaling_factor = rope_scaling.get("factor", 1)
if "original_max_position_embeddings" in rope_scaling:
rope_scaling_factor = 1
if config.rope_scaling.get("rope_type", None) == "llama3":
if rope_scaling.get("rope_type", None) == "llama3":
rope_scaling_factor = 1
else:
rope_scaling_factor = 1
for key in CONTEXT_LENGTH_KEYS:
val = getattr(config, key, None)
val = getattr(text_config, key, None)
if val is not None:
return int(rope_scaling_factor * val)
return 2048