diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py
index d440fa70c..61642cba5 100644
--- a/python/sglang/srt/lora/utils.py
+++ b/python/sglang/srt/lora/utils.py
@@ -92,11 +92,30 @@ def get_hidden_dim(
         Please implement the function in the model class if it is not.
         You can reference this function in llama.py.
         """
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return config.hidden_size, config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return config.hidden_size, config.hidden_size // (
-                config.num_attention_heads // config.num_key_value_heads
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+
+        # TODO: the special handling of qkv will be addressed in #8940.
+        if module_name == "qkv_proj":
+            return (
+                config.hidden_size,
+                None,  # qkv_proj is only used in LoRA A
+            )
+        elif module_name == "kv_proj":
+            return (
+                None,  # kv_proj is only used in LoRA B
+                head_dim * config.num_key_value_heads,
+            )
+        elif module_name == "q_proj":
+            return (
+                None,  # q_proj is only used in LoRA B
+                head_dim * config.num_attention_heads,
+            )
+        elif module_name == "o_proj":
+            return (
+                head_dim * config.num_attention_heads,
+                config.hidden_size,
             )
         elif module_name == "gate_up_proj":
             return config.hidden_size, config.intermediate_size
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index ee490d083..b9b4e4cec 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):
 
         return result
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
 
diff --git a/python/sglang/srt/models/gemma3n_mm.py b/python/sglang/srt/models/gemma3n_mm.py
index b4bf2ba75..f9c58eaae 100644
--- a/python/sglang/srt/models/gemma3n_mm.py
+++ b/python/sglang/srt/models/gemma3n_mm.py
@@ -501,9 +501,20 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
 
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
+        # TODO: the special handling of qkv will be addressed in #8940.
+        if module_name == "qkv_proj":
             return (
                 self.config.hidden_size,
+                None,  # qkv_proj is only used in LoRA A
+            )
+        elif module_name == "kv_proj":
+            return (
+                None,  # kv_proj is only used in LoRA B
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "q_proj":
+            return (
+                None,  # q_proj is only used in LoRA B
                 self.config.head_dim * self.config.num_attention_heads,
             )
         elif module_name in ["o_proj"]:
@@ -511,11 +522,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
                 self.config.head_dim * self.config.num_attention_heads,
                 self.config.hidden_size,
             )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
         elif module_name == "gate_up_proj":
             assert len(set(self.config.intermediate_size)) == 1, (
                 "Currently SGLang requires uniform intermediate size for all layers. "
diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py
index 26fccc48d..19252dc8d 100644
--- a/python/sglang/srt/models/granite.py
+++ b/python/sglang/srt/models/granite.py
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index d1614935b..4efbc48fd 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -532,31 +532,6 @@ class LlamaForCausalLM(nn.Module):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 7d7c3bf7b..6289e61e7 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
     @torch.no_grad()
     def forward(
         self,
diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py
index 494ef80ed..630e5feb8 100644
--- a/python/sglang/srt/models/torch_native_llama.py
+++ b/python/sglang/srt/models/torch_native_llama.py
@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
            # (param_name, shard_name, shard_id, num_shard)
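For reference, a minimal sketch of the shapes the updated generic `get_hidden_dim` in `python/sglang/srt/lora/utils.py` now reports, assuming a hypothetical GQA config; the `SimpleNamespace` stand-in and the concrete sizes below are illustrative only and are not taken from this patch (real callers pass a `transformers` config object):

```python
from types import SimpleNamespace

# Hypothetical Llama-3-8B-like shape; illustrative values, not from this patch.
config = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
)

# Same fallback the patch introduces: prefer an explicit config.head_dim,
# otherwise derive it from hidden_size and the number of attention heads.
head_dim = getattr(
    config, "head_dim", config.hidden_size // config.num_attention_heads
)

print(head_dim)                               # 128
print(head_dim * config.num_attention_heads)  # 4096 -> q_proj LoRA-B output dim
print(head_dim * config.num_key_value_heads)  # 1024 -> kv_proj LoRA-B output dim
```

Models that set an explicit `head_dim` different from `hidden_size // num_attention_heads` are exactly where the old fallback formulas (`config.hidden_size` for q/o and `hidden_size // (num_attention_heads // num_key_value_heads)` for kv) diverge from the head-dim-based computation, which is presumably why gemma2.py and qwen3.py previously carried their own overrides; with the shared fallback handling it, those per-model copies can be deleted.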