Fix incorrect default get_hidden_dim logic (#8987)

This commit is contained in:
Lifu Huang
2025-08-09 00:25:38 -07:00
committed by GitHub
parent a47baff12c
commit 6e2151183b
7 changed files with 36 additions and 143 deletions

View File

@@ -92,11 +92,30 @@ def get_hidden_dim(
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return config.hidden_size, config.hidden_size
elif module_name in ["kv_proj"]:
return config.hidden_size, config.hidden_size // (
config.num_attention_heads // config.num_key_value_heads
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
# TODO: the special handling of qkv will be addressed in #8940.
if module_name == "qkv_proj":
return (
config.hidden_size,
None, # qkv_proj is only used in LoRA A
)
elif module_name == "kv_proj":
return (
None, # kv_proj is only used in LoRA B
head_dim * config.num_key_value_heads,
)
elif module_name == "q_proj":
return (
None, # q_proj is only used in LoRA B
head_dim * config.num_attention_heads,
)
elif module_name == "o_proj":
return (
head_dim * config.num_attention_heads,
config.hidden_size,
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size