Fix llama2 weight loader (#1317)
This commit is contained in:
@@ -323,27 +323,6 @@ class ExaoneForCausalLM(nn.Module):
|
|||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
return sample_output, logits_output
|
return sample_output, logits_output
|
||||||
|
|
||||||
def get_module_name(self, name):
|
|
||||||
stacked_params_mapping = [
|
|
||||||
# (param_name, shard_name, shard_id, num_shard)
|
|
||||||
("qkv_proj", "q_proj", "q", 3),
|
|
||||||
("qkv_proj", "k_proj", "k", 3),
|
|
||||||
("qkv_proj", "v_proj", "v", 3),
|
|
||||||
("gate_up_proj", "c_fc_0", 0, 2),
|
|
||||||
("gate_up_proj", "c_fc_1", 1, 2),
|
|
||||||
]
|
|
||||||
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
|
|
||||||
if weight_name in name:
|
|
||||||
return (
|
|
||||||
name.replace(weight_name, param_name)[: -len(".weight")],
|
|
||||||
num_shard,
|
|
||||||
)
|
|
||||||
return name[: -len(".weight")], 1
|
|
||||||
|
|
||||||
def get_num_params(self):
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
return len(params_dict)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
@@ -357,13 +336,13 @@ class ExaoneForCausalLM(nn.Module):
|
|||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
return
|
continue
|
||||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
return
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
return
|
continue
|
||||||
|
|
||||||
name = name.replace("attn.attention", "self_attn")
|
name = name.replace("attn.attention", "self_attn")
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
@@ -380,7 +359,7 @@ class ExaoneForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
return
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@@ -334,13 +334,13 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
return
|
continue
|
||||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
return
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
return
|
continue
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
@@ -356,7 +356,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
return
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
Reference in New Issue
Block a user