Fix incorrect default get_hidden_dim logic (#8987)
This commit is contained in:
@@ -92,11 +92,30 @@ def get_hidden_dim(
 Please implement the function in the model class if it is not.
 You can reference this function in llama.py.
 """
-if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-    return config.hidden_size, config.hidden_size
-elif module_name in ["kv_proj"]:
-    return config.hidden_size, config.hidden_size // (
-        config.num_attention_heads // config.num_key_value_heads
+head_dim = getattr(
+    config, "head_dim", config.hidden_size // config.num_attention_heads
 )
+
+# TODO: the special handling of qkv will be addressed in #8940.
+if module_name == "qkv_proj":
+    return (
+        config.hidden_size,
+        None, # qkv_proj is only used in LoRA A
+    )
+elif module_name == "kv_proj":
+    return (
+        None, # kv_proj is only used in LoRA B
+        head_dim * config.num_key_value_heads,
+    )
+elif module_name == "q_proj":
+    return (
+        None, # q_proj is only used in LoRA B
+        head_dim * config.num_attention_heads,
+    )
+elif module_name == "o_proj":
+    return (
+        head_dim * config.num_attention_heads,
+        config.hidden_size,
+    )
 elif module_name == "gate_up_proj":
     return config.hidden_size, config.intermediate_size
Reference in New Issue
Block a user