Fix incorrect context length for llama3.2-11b (#1873)
This commit is contained in:
@@ -88,19 +88,23 @@ CONTEXT_LENGTH_KEYS = [
|
|||||||
|
|
||||||
|
|
||||||
def get_context_length(config):
|
def get_context_length(config):
|
||||||
"""Get the context length of a model from a huggingface model configs."""
|
"""Get the context length of a model from a huggingface model configs.
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
The values are read from the text_config sub-config when the model is a multimodal
|
||||||
|
LLM.
|
||||||
|
"""
|
||||||
|
text_config = getattr(config, "text_config", config)
|
||||||
|
rope_scaling = getattr(text_config, "rope_scaling", None)
|
||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
rope_scaling_factor = config.rope_scaling.get("factor", 1)
|
rope_scaling_factor = rope_scaling.get("factor", 1)
|
||||||
if "original_max_position_embeddings" in rope_scaling:
|
if "original_max_position_embeddings" in rope_scaling:
|
||||||
rope_scaling_factor = 1
|
rope_scaling_factor = 1
|
||||||
if config.rope_scaling.get("rope_type", None) == "llama3":
|
if rope_scaling.get("rope_type", None) == "llama3":
|
||||||
rope_scaling_factor = 1
|
rope_scaling_factor = 1
|
||||||
else:
|
else:
|
||||||
rope_scaling_factor = 1
|
rope_scaling_factor = 1
|
||||||
|
|
||||||
for key in CONTEXT_LENGTH_KEYS:
|
for key in CONTEXT_LENGTH_KEYS:
|
||||||
val = getattr(config, key, None)
|
val = getattr(text_config, key, None)
|
||||||
if val is not None:
|
if val is not None:
|
||||||
return int(rope_scaling_factor * val)
|
return int(rope_scaling_factor * val)
|
||||||
return 2048
|
return 2048
|
||||||
|
|||||||
Reference in New Issue
Block a user