Fix incorrect context length for llama3.2-11b (#1873)
@@ -88,19 +88,23 @@ CONTEXT_LENGTH_KEYS = [
 def get_context_length(config):
-    """Get the context length of a model from a huggingface model configs."""
-    rope_scaling = getattr(config, "rope_scaling", None)
+    """Get the context length of a model from a huggingface model configs.
+    And here the config should be text_config part if the model is a multimodal
+    LLM.
+    """
+    text_config = getattr(config, "text_config", config)
+    rope_scaling = getattr(text_config, "rope_scaling", None)
     if rope_scaling:
-        rope_scaling_factor = config.rope_scaling.get("factor", 1)
+        rope_scaling_factor = rope_scaling.get("factor", 1)
         if "original_max_position_embeddings" in rope_scaling:
             rope_scaling_factor = 1
-        if config.rope_scaling.get("rope_type", None) == "llama3":
+        if rope_scaling.get("rope_type", None) == "llama3":
             rope_scaling_factor = 1
     else:
         rope_scaling_factor = 1

     for key in CONTEXT_LENGTH_KEYS:
-        val = getattr(config, key, None)
+        val = getattr(text_config, key, None)
         if val is not None:
             return int(rope_scaling_factor * val)
     return 2048
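
For context, here is a minimal sketch of how the fixed function behaves on a multimodal config. The function body is copied from the post-fix code above; CONTEXT_LENGTH_KEYS is trimmed to a single stand-in entry, and the mock config values only approximate the published Llama-3.2-11B-Vision config, so treat both as illustrative assumptions rather than the repo's actual test setup.

    from types import SimpleNamespace

    # Trimmed stand-in: the real module-level list referenced in the hunk
    # header holds several alternative key names.
    CONTEXT_LENGTH_KEYS = ["max_position_embeddings"]


    def get_context_length(config):
        """Get the context length of a model from a huggingface model configs.
        And here the config should be text_config part if the model is a multimodal
        LLM.
        """
        text_config = getattr(config, "text_config", config)
        rope_scaling = getattr(text_config, "rope_scaling", None)
        if rope_scaling:
            rope_scaling_factor = rope_scaling.get("factor", 1)
            # An original_max_position_embeddings key or a llama3 rope_type
            # signals that max_position_embeddings is already the
            # post-scaling length, so no extra multiplication is applied.
            if "original_max_position_embeddings" in rope_scaling:
                rope_scaling_factor = 1
            if rope_scaling.get("rope_type", None) == "llama3":
                rope_scaling_factor = 1
        else:
            rope_scaling_factor = 1

        for key in CONTEXT_LENGTH_KEYS:
            val = getattr(text_config, key, None)
            if val is not None:
                return int(rope_scaling_factor * val)
        return 2048


    # Mock of a multimodal config whose fields live under text_config,
    # roughly like Llama-3.2-11B-Vision (values illustrative).
    mm_config = SimpleNamespace(
        text_config=SimpleNamespace(
            max_position_embeddings=131072,
            rope_scaling={
                "rope_type": "llama3",
                "factor": 8.0,
                "original_max_position_embeddings": 8192,
            },
        )
    )

    # Pre-fix, both getattr lookups ran against the top-level config, found
    # neither rope_scaling nor any context-length key there, and fell back
    # to 2048; with the text_config fallback the real window is reported.
    print(get_context_length(mm_config))  # -> 131072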