Fix context length (#757)
This commit is contained in:
@@ -73,6 +73,8 @@ def get_context_length(config):
|
|||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
rope_scaling_factor = config.rope_scaling["factor"]
|
rope_scaling_factor = config.rope_scaling["factor"]
|
||||||
|
if "original_max_position_embeddings" in rope_scaling:
|
||||||
|
rope_scaling_factor = 1
|
||||||
if config.rope_scaling.get("rope_type", None) == "llama3":
|
if config.rope_scaling.get("rope_type", None) == "llama3":
|
||||||
rope_scaling_factor = 1
|
rope_scaling_factor = 1
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user