Support Phi3 mini and medium (#1299)
This commit is contained in:
@@ -92,7 +92,7 @@ def get_context_length(config):
|
|||||||
"""Get the context length of a model from a huggingface model configs."""
|
"""Get the context length of a model from a huggingface model configs."""
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
rope_scaling_factor = config.rope_scaling["factor"]
|
rope_scaling_factor = config.rope_scaling.get("factor", 1)
|
||||||
if "original_max_position_embeddings" in rope_scaling:
|
if "original_max_position_embeddings" in rope_scaling:
|
||||||
rope_scaling_factor = 1
|
rope_scaling_factor = 1
|
||||||
if config.rope_scaling.get("rope_type", None) == "llama3":
|
if config.rope_scaling.get("rope_type", None) == "llama3":
|
||||||
|
|||||||
@@ -324,11 +324,11 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("qkv_proj", "q_proj", "q"),
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
("qkv_proj", "k_proj", "k"),
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
("gate_up_proj", "gate_proj", 0),
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
("gate_up_proj", "up_proj", 1),
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
]
|
]
|
||||||
params_dict = self.param_dict
|
params_dict = self.param_dict
|
||||||
|
|
||||||
@@ -362,4 +362,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = LlamaForCausalLM
|
class Phi3ForCausalLM(LlamaForCausalLM):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
|
||||||
|
|||||||
Reference in New Issue
Block a user