Fix shape error that occurred when loading LoRA weights for the Gemma2 model. (#2330)
This commit is contained in:
committed by
GitHub
parent
ef995dae1e
commit
63dfab1bea
@@ -355,6 +355,40 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_hidden_dim(self, module_name):
    """Return the ``(input_dim, output_dim)`` pair for a named projection.

    The widths are read from ``self.config``. The attention width is computed
    as ``head_dim * num_attention_heads`` (or ``num_key_value_heads`` for
    ``kv_proj``) instead of being taken from ``hidden_size``; the two values
    need not be equal.

    Args:
        module_name: One of ``"q_proj"``, ``"qkv_proj"``, ``"o_proj"``,
            ``"kv_proj"``, ``"gate_up_proj"`` or ``"down_proj"``.

    Returns:
        A 2-tuple ``(input_dim, output_dim)``.

    Raises:
        NotImplementedError: If ``module_name`` is not one of the names above.
    """
    if module_name in ("q_proj", "qkv_proj"):
        # Note: for "qkv_proj" the reported output width is the query width
        # only, exactly as for "q_proj".
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_attention_heads,
        )
    if module_name == "o_proj":
        # Output projection maps the attention width back to hidden_size.
        return (
            self.config.head_dim * self.config.num_attention_heads,
            self.config.hidden_size,
        )
    if module_name == "kv_proj":
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_key_value_heads,
        )
    if module_name == "gate_up_proj":
        return self.config.hidden_size, self.config.intermediate_size
    if module_name == "down_proj":
        return self.config.intermediate_size, self.config.hidden_size
    # Name the rejected module so the failure is diagnosable from the traceback.
    raise NotImplementedError(
        f"get_hidden_dim does not support module {module_name!r}"
    )
|
||||||
|
|
||||||
|
def get_module_name(self, name):
    """Map a per-projection weight name to the fused module that holds it.

    ``q_proj``/``k_proj``/``v_proj`` map to ``"qkv_proj"`` and
    ``gate_proj``/``up_proj`` map to ``"gate_up_proj"``. Any other name is
    returned unchanged.

    Args:
        name: Projection name to translate.

    Returns:
        The fused module name, or ``name`` itself when no mapping exists.
    """
    params_mapping = {
        "q_proj": "qkv_proj",
        "k_proj": "qkv_proj",
        "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj",
        "up_proj": "gate_up_proj",
    }
    # Unmapped names (e.g. "o_proj", "down_proj") pass through as-is.
    return params_mapping.get(name, name)
|
||||||
|
|
||||||
def get_attention_sliding_window_size(self):
    """Return the attention sliding-window size for this model's config.

    Delegates to the same-named helper, passing ``self.config``.
    NOTE(review): that helper is defined outside the visible source; its
    return type and semantics should be confirmed there.
    """
    return get_attention_sliding_window_size(self.config)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user