Fix incorrect default get_hidden_dim logic (#8987)

Author: Lifu Huang
Date: 2025-08-09 00:25:38 -07:00
Committed by: GitHub
Commit: 6e2151183b (parent a47baff12c)
7 changed files with 36 additions and 143 deletions

View File

@@ -92,11 +92,30 @@ def get_hidden_dim(
     Please implement the function in the model class if it is not.
     You can reference this function in llama.py.
     """
-    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-        return config.hidden_size, config.hidden_size
-    elif module_name in ["kv_proj"]:
-        return config.hidden_size, config.hidden_size // (
-            config.num_attention_heads // config.num_key_value_heads
+    head_dim = getattr(
+        config, "head_dim", config.hidden_size // config.num_attention_heads
+    )
+
+    # TODO: the special handling of qkv will be addressed in #8940.
+    if module_name == "qkv_proj":
+        return (
+            config.hidden_size,
+            None,  # qkv_proj is only used in LoRA A
+        )
+    elif module_name == "kv_proj":
+        return (
+            None,  # kv_proj is only used in LoRA B
+            head_dim * config.num_key_value_heads,
+        )
+    elif module_name == "q_proj":
+        return (
+            None,  # q_proj is only used in LoRA B
+            head_dim * config.num_attention_heads,
+        )
+    elif module_name == "o_proj":
+        return (
+            head_dim * config.num_attention_heads,
+            config.hidden_size,
         )
     elif module_name == "gate_up_proj":
         return config.hidden_size, config.intermediate_size
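A quick illustration (not part of this commit; the config values are hypothetical) of why the old fallback above was wrong: it assumed head_dim == hidden_size // num_attention_heads, which does not hold for models that set head_dim explicitly in their config.

from types import SimpleNamespace

# Hypothetical GQA-style config where head_dim is set explicitly and differs
# from hidden_size // num_attention_heads.
config = SimpleNamespace(
    hidden_size=2048,
    num_attention_heads=16,
    num_key_value_heads=4,
    head_dim=256,  # note: 2048 // 16 == 128, not 256
)

head_dim = getattr(
    config, "head_dim", config.hidden_size // config.num_attention_heads
)

# Old fallback for the kv_proj output dim:
old_kv_out = config.hidden_size // (
    config.num_attention_heads // config.num_key_value_heads
)
# Corrected fallback, matching the real width of the k/v projections:
new_kv_out = head_dim * config.num_key_value_heads

print(old_kv_out, new_kv_out)  # 512 1024 -- the old value is wrong here
# The old fallback likewise used hidden_size (2048) for q_proj / o_proj,
# whereas the actual projection width is head_dim * num_attention_heads (4096).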

View File

@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):
         return result
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
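A spot check of why this override could be dropped (hypothetical config values, not taken from the commit): with head_dim now read via getattr in the shared fallback, the default path yields the same dims the deleted Gemma2 method returned for o_proj and kv_proj; the q/qkv entries differ only on the side marked None, which the corresponding LoRA weight never reads.

from types import SimpleNamespace

# Hypothetical Gemma2-style config; real values may differ.
cfg = SimpleNamespace(
    hidden_size=3584, num_attention_heads=16, num_key_value_heads=8, head_dim=256
)
head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)

# o_proj: the deleted override returned (head_dim * num_heads, hidden_size);
# the corrected shared fallback (first hunk above) computes the same pair.
assert (head_dim * cfg.num_attention_heads, cfg.hidden_size) == (4096, 3584)

# kv_proj: the override returned (hidden_size, head_dim * num_kv_heads); the
# fallback now returns (None, head_dim * num_kv_heads) -- the output dim matches,
# and the input dim is unused because kv_proj only occurs on the LoRA B side.
assert head_dim * cfg.num_key_value_heads == 2048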

View File

@@ -501,9 +501,20 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
+        # TODO: the special handling of qkv will be addressed in #8940.
+        if module_name == "qkv_proj":
             return (
                 self.config.hidden_size,
+                None,  # qkv_proj is only used in LoRA A
+            )
+        elif module_name == "kv_proj":
+            return (
+                None,  # kv_proj is only used in LoRA B
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "q_proj":
+            return (
+                None,  # q_proj is only used in LoRA B
                 self.config.head_dim * self.config.num_attention_heads,
             )
         elif module_name in ["o_proj"]:
@@ -511,11 +522,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
                 self.config.head_dim * self.config.num_attention_heads,
                 self.config.hidden_size,
             )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
         elif module_name == "gate_up_proj":
             assert len(set(self.config.intermediate_size)) == 1, (
                 "Currently SGLang requires uniform intermediate size for all layers. "

View File

@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:

View File

@@ -532,31 +532,6 @@ class LlamaForCausalLM(nn.Module):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:

View File

@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
     @torch.no_grad()
     def forward(
         self,

View File

@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)