Fix incorrect default get_hidden_dim logic (#8987)

Lifu Huang
2025-08-09 00:25:38 -07:00
committed by GitHub
parent a47baff12c
commit 6e2151183b
7 changed files with 36 additions and 143 deletions
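
The override removed below supplied the per-module (input_dim, output_dim) pairs that downstream code (presumably LoRA adapter weight sizing) looks up; the commit title suggests the shared default logic was corrected, making this model-specific copy redundant. As a rough, standalone illustration of what that mapping computes (not the repository's actual default implementation), here is a minimal sketch assuming a Llama-style config with grouped-query attention; LlamaLikeConfig and its field values are hypothetical stand-ins:

from dataclasses import dataclass


@dataclass
class LlamaLikeConfig:
    # Hypothetical stand-in for a HuggingFace-style Llama config
    # (assumption, not the config class used by the repository).
    hidden_size: int = 4096
    intermediate_size: int = 11008
    num_attention_heads: int = 32
    num_key_value_heads: int = 8


def get_hidden_dim(config, module_name):
    # Mirrors the per-module (input_dim, output_dim) pairs in the override
    # removed by this commit.
    if module_name in ("q_proj", "o_proj", "qkv_proj"):
        return config.hidden_size, config.hidden_size
    if module_name == "kv_proj":
        # Under grouped-query attention the k/v projections are narrower by
        # the ratio of query heads to key/value heads.
        return config.hidden_size, config.hidden_size // (
            config.num_attention_heads // config.num_key_value_heads
        )
    if module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    if module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    raise NotImplementedError(module_name)


print(get_hidden_dim(LlamaLikeConfig(), "kv_proj"))  # -> (4096, 1024)

With 32 query heads and 8 key/value heads, kv_proj maps 4096 in to 1024 out, i.e. one quarter of the hidden size, while the q/o projections stay square.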

@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)
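
The other deleted helper, get_module_name, translated per-weight parameter names onto the fused modules the model registers (q/k/v into qkv_proj, gate/up into gate_up_proj). A minimal standalone sketch of that lookup, mirroring the mapping in the removed code above:

# Per-weight -> fused-module mapping, mirrored from the removed get_module_name.
PARAMS_MAPPING = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}


def get_module_name(name):
    # Names without a fused counterpart (e.g. o_proj, down_proj) pass through.
    return PARAMS_MAPPING.get(name, name)


assert get_module_name("k_proj") == "qkv_proj"
assert get_module_name("down_proj") == "down_proj"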