fix rotary_embedding rope_scaling for phi (#3055)
This commit is contained in:
@@ -1018,7 +1018,12 @@ def get_rope(
|
|||||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if "rope_type" in rope_scaling:
|
||||||
scaling_type = rope_scaling["rope_type"]
|
scaling_type = rope_scaling["rope_type"]
|
||||||
|
elif "type" in rope_scaling:
|
||||||
|
scaling_type = rope_scaling["type"]
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown RoPE scaling type")
|
||||||
|
|
||||||
if scaling_type == "llama3":
|
if scaling_type == "llama3":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
|||||||
Reference in New Issue
Block a user