Fix incorrect default get_hidden_dim logic (#8987)
This commit is contained in:
@@ -92,11 +92,30 @@ def get_hidden_dim(
|
|||||||
Please implement the function in the model class if it is not.
|
Please implement the function in the model class if it is not.
|
||||||
You can reference this function in llama.py.
|
You can reference this function in llama.py.
|
||||||
"""
|
"""
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
head_dim = getattr(
|
||||||
return config.hidden_size, config.hidden_size
|
config, "head_dim", config.hidden_size // config.num_attention_heads
|
||||||
elif module_name in ["kv_proj"]:
|
)
|
||||||
return config.hidden_size, config.hidden_size // (
|
|
||||||
config.num_attention_heads // config.num_key_value_heads
|
# TODO: the special handling of qkv will be addressed in #8940.
|
||||||
|
if module_name == "qkv_proj":
|
||||||
|
return (
|
||||||
|
config.hidden_size,
|
||||||
|
None, # qkv_proj is only used in LoRA A
|
||||||
|
)
|
||||||
|
elif module_name == "kv_proj":
|
||||||
|
return (
|
||||||
|
None, # kv_proj is only used in LoRA B
|
||||||
|
head_dim * config.num_key_value_heads,
|
||||||
|
)
|
||||||
|
elif module_name == "q_proj":
|
||||||
|
return (
|
||||||
|
None, # q_proj is only used in LoRA B
|
||||||
|
head_dim * config.num_attention_heads,
|
||||||
|
)
|
||||||
|
elif module_name == "o_proj":
|
||||||
|
return (
|
||||||
|
head_dim * config.num_attention_heads,
|
||||||
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
elif module_name == "gate_up_proj":
|
elif module_name == "gate_up_proj":
|
||||||
return config.hidden_size, config.intermediate_size
|
return config.hidden_size, config.intermediate_size
|
||||||
|
|||||||
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
|
||||||
# return input_dim, output_dim
|
|
||||||
if module_name in ["q_proj", "qkv_proj"]:
|
|
||||||
return (
|
|
||||||
self.config.hidden_size,
|
|
||||||
self.config.head_dim * self.config.num_attention_heads,
|
|
||||||
)
|
|
||||||
elif module_name in ["o_proj"]:
|
|
||||||
return (
|
|
||||||
self.config.head_dim * self.config.num_attention_heads,
|
|
||||||
self.config.hidden_size,
|
|
||||||
)
|
|
||||||
elif module_name in ["kv_proj"]:
|
|
||||||
return (
|
|
||||||
self.config.hidden_size,
|
|
||||||
self.config.head_dim * self.config.num_key_value_heads,
|
|
||||||
)
|
|
||||||
elif module_name == "gate_up_proj":
|
|
||||||
return self.config.hidden_size, self.config.intermediate_size
|
|
||||||
elif module_name == "down_proj":
|
|
||||||
return self.config.intermediate_size, self.config.hidden_size
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_module_name(self, name):
|
|
||||||
params_mapping = {
|
|
||||||
"q_proj": "qkv_proj",
|
|
||||||
"k_proj": "qkv_proj",
|
|
||||||
"v_proj": "qkv_proj",
|
|
||||||
"gate_proj": "gate_up_proj",
|
|
||||||
"up_proj": "gate_up_proj",
|
|
||||||
}
|
|
||||||
return params_mapping.get(name, name)
|
|
||||||
|
|
||||||
def get_attention_sliding_window_size(self):
|
def get_attention_sliding_window_size(self):
|
||||||
return get_attention_sliding_window_size(self.config)
|
return get_attention_sliding_window_size(self.config)
|
||||||
|
|
||||||
|
|||||||
@@ -501,9 +501,20 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
|||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
# return input_dim, output_dim
|
# return input_dim, output_dim
|
||||||
if module_name in ["q_proj", "qkv_proj"]:
|
# TODO: the special handling of qkv will be addressed in #8940.
|
||||||
|
if module_name == "qkv_proj":
|
||||||
return (
|
return (
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
|
None, # qkv_proj is only used in LoRA A
|
||||||
|
)
|
||||||
|
elif module_name == "kv_proj":
|
||||||
|
return (
|
||||||
|
None, # kv_proj is only used in LoRA B
|
||||||
|
self.config.head_dim * self.config.num_key_value_heads,
|
||||||
|
)
|
||||||
|
elif module_name == "q_proj":
|
||||||
|
return (
|
||||||
|
None, # q_proj is only used in LoRA B
|
||||||
self.config.head_dim * self.config.num_attention_heads,
|
self.config.head_dim * self.config.num_attention_heads,
|
||||||
)
|
)
|
||||||
elif module_name in ["o_proj"]:
|
elif module_name in ["o_proj"]:
|
||||||
@@ -511,11 +522,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
|||||||
self.config.head_dim * self.config.num_attention_heads,
|
self.config.head_dim * self.config.num_attention_heads,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
)
|
)
|
||||||
elif module_name in ["kv_proj"]:
|
|
||||||
return (
|
|
||||||
self.config.hidden_size,
|
|
||||||
self.config.head_dim * self.config.num_key_value_heads,
|
|
||||||
)
|
|
||||||
elif module_name == "gate_up_proj":
|
elif module_name == "gate_up_proj":
|
||||||
assert len(set(self.config.intermediate_size)) == 1, (
|
assert len(set(self.config.intermediate_size)) == 1, (
|
||||||
"Currently SGLang requires uniform intermediate size for all layers. "
|
"Currently SGLang requires uniform intermediate size for all layers. "
|
||||||
|
|||||||
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, forward_batch)
|
return self.pooler(hidden_states, forward_batch)
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
|
||||||
# return input_dim, output_dim
|
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
|
||||||
return self.config.hidden_size, self.config.hidden_size
|
|
||||||
elif module_name in ["kv_proj"]:
|
|
||||||
return self.config.hidden_size, self.config.hidden_size // (
|
|
||||||
self.config.num_attention_heads // self.config.num_key_value_heads
|
|
||||||
)
|
|
||||||
elif module_name == "gate_up_proj":
|
|
||||||
return self.config.hidden_size, self.config.intermediate_size
|
|
||||||
elif module_name == "down_proj":
|
|
||||||
return self.config.intermediate_size, self.config.hidden_size
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_module_name(self, name):
|
|
||||||
params_mapping = {
|
|
||||||
"q_proj": "qkv_proj",
|
|
||||||
"k_proj": "qkv_proj",
|
|
||||||
"v_proj": "qkv_proj",
|
|
||||||
"gate_proj": "gate_up_proj",
|
|
||||||
"up_proj": "gate_up_proj",
|
|
||||||
}
|
|
||||||
return params_mapping.get(name, name)
|
|
||||||
|
|
||||||
def get_module_name_from_weight_name(self, name):
|
def get_module_name_from_weight_name(self, name):
|
||||||
for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
|
for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
|
||||||
if weight_name in name:
|
if weight_name in name:
|
||||||
|
|||||||
@@ -532,31 +532,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
|
||||||
# return input_dim, output_dim
|
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
|
||||||
return self.config.hidden_size, self.config.hidden_size
|
|
||||||
elif module_name in ["kv_proj"]:
|
|
||||||
return self.config.hidden_size, self.config.hidden_size // (
|
|
||||||
self.config.num_attention_heads // self.config.num_key_value_heads
|
|
||||||
)
|
|
||||||
elif module_name == "gate_up_proj":
|
|
||||||
return self.config.hidden_size, self.config.intermediate_size
|
|
||||||
elif module_name == "down_proj":
|
|
||||||
return self.config.intermediate_size, self.config.hidden_size
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_module_name(self, name):
|
|
||||||
params_mapping = {
|
|
||||||
"q_proj": "qkv_proj",
|
|
||||||
"k_proj": "qkv_proj",
|
|
||||||
"v_proj": "qkv_proj",
|
|
||||||
"gate_proj": "gate_up_proj",
|
|
||||||
"up_proj": "gate_up_proj",
|
|
||||||
}
|
|
||||||
return params_mapping.get(name, name)
|
|
||||||
|
|
||||||
def get_module_name_from_weight_name(self, name):
|
def get_module_name_from_weight_name(self, name):
|
||||||
for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
|
for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
|
||||||
if weight_name in name:
|
if weight_name in name:
|
||||||
|
|||||||
@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
|
|||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.model.get_input_embeddings(input_ids)
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name: str) -> Tuple[int]:
|
|
||||||
# return input_dim, output_dim
|
|
||||||
if module_name in ["q_proj", "qkv_proj"]:
|
|
||||||
return (
|
|
||||||
self.config.hidden_size,
|
|
||||||
self.config.head_dim * self.config.num_attention_heads,
|
|
||||||
)
|
|
||||||
elif module_name in ["o_proj"]:
|
|
||||||
return (
|
|
||||||
self.config.head_dim * self.config.num_attention_heads,
|
|
||||||
self.config.hidden_size,
|
|
||||||
)
|
|
||||||
elif module_name in ["kv_proj"]:
|
|
||||||
return (
|
|
||||||
self.config.hidden_size,
|
|
||||||
self.config.head_dim * self.config.num_key_value_heads,
|
|
||||||
)
|
|
||||||
elif module_name == "gate_up_proj":
|
|
||||||
return self.config.hidden_size, self.config.intermediate_size
|
|
||||||
elif module_name == "down_proj":
|
|
||||||
return self.config.intermediate_size, self.config.hidden_size
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|||||||
input_ids, hidden_states, self.lm_head, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
|
||||||
return self.config.hidden_size, self.config.hidden_size
|
|
||||||
elif module_name in ["kv_proj"]:
|
|
||||||
return self.config.hidden_size, self.config.hidden_size // (
|
|
||||||
self.config.num_attention_heads // self.config.num_key_value_heads
|
|
||||||
)
|
|
||||||
elif module_name == "gate_up_proj":
|
|
||||||
return self.config.hidden_size, self.config.intermediate_size
|
|
||||||
elif module_name == "down_proj":
|
|
||||||
return self.config.intermediate_size, self.config.hidden_size
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_module_name(self, name):
|
|
||||||
params_mapping = {
|
|
||||||
"q_proj": "qkv_proj",
|
|
||||||
"k_proj": "qkv_proj",
|
|
||||||
"v_proj": "qkv_proj",
|
|
||||||
"gate_proj": "gate_up_proj",
|
|
||||||
"up_proj": "gate_up_proj",
|
|
||||||
}
|
|
||||||
return params_mapping.get(name, name)
|
|
||||||
|
|
||||||
def get_module_name_from_weight_name(self, name):
|
def get_module_name_from_weight_name(self, name):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id, num_shard)
|
# (param_name, shard_name, shard_id, num_shard)
|
||||||
|
|||||||
Reference in New Issue
Block a user