From e6852b0dd28979d548885c6b203e0b74616d9af7 Mon Sep 17 00:00:00 2001 From: Minsang Song Date: Thu, 3 Oct 2024 12:41:15 +0900 Subject: [PATCH] [Fix] Fix AttributeError in Qwen2.5 LoRA: 'Qwen2ForCausalLM' object has no attribute 'get_hidden_dim' (#1536) Co-authored-by: Ying Sheng --- python/sglang/srt/lora/lora_manager.py | 74 +++++++++++++++++++++-- python/sglang/srt/models/llama.py | 1 + test/srt/models/test_generation_models.py | 1 + test/srt/models/test_lora.py | 5 +- 4 files changed, 73 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 84f55082a..59cd7e157 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -17,6 +17,7 @@ limitations under the License. # and "Punica: Multi-Tenant LoRA Serving" +import logging import re import torch @@ -26,11 +27,48 @@ from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_hip, replace_submodule +logger = logging.getLogger(__name__) + + # ROCm: flashinfer available later if not is_hip(): from flashinfer import SegmentGEMMWrapper +def get_module_name(name): + # Fallback solution of mapping from config module name to module name in model class. + # Please check if it aligns with your base model. + # Please implement the function in the model class if it is not. + # You can reference this function in llama.py. + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + + +def get_hidden_dim(module_name, config): + # Fallback solution of get_hidden_dim for different modules + # Please check if it aligns with your base model. + # Please implement the function in the model class if it is not. + # You can reference this function in llama.py. 
+ if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return config.hidden_size, config.hidden_size + elif module_name in ["kv_proj"]: + return config.hidden_size, config.hidden_size // ( + config.num_attention_heads // config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return config.hidden_size, config.intermediate_size + elif module_name == "down_proj": + return config.intermediate_size, config.hidden_size + else: + raise NotImplementedError() + + def get_stacked_name(name): # origin name -> (name for A, name for B) params_mapping = { @@ -103,12 +141,20 @@ class LoRAManager: self.origin_target_modules = set(self.origin_target_modules) | set( self.configs[name].target_modules ) - self.target_modules = set( - [ + if hasattr(self.base_model, "get_module_name"): + self.target_modules = { self.base_model.get_module_name(module) for module in self.origin_target_modules - ] - ) + } + else: + logger.warning( + f"WARNING: get_module_name() is not defined, " + f"which is used to map config module name to model implementation module name. " + f"Use the default one, but please check if it is correct for your model." + ) + self.target_modules = { + get_module_name(module) for module in self.origin_target_modules + } self.target_weights = set( [get_stacked_name(module) for module in self.origin_target_modules] ) @@ -146,7 +192,15 @@ class LoRAManager: num_layer = self.base_hf_config.num_hidden_layers for module_A, module_B in self.target_weights: # init A tensor, column_major=True - hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A) + if hasattr(self.base_model, "get_hidden_dim"): + hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A) + else: + logger.warning( + f"WARNING: get_hidden_dim() is not defined, " + f"which is used to get the hidden dim for different lora modules. " + f"Use the default one, but please check if it is correct for your model." 
+ ) + hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config) c = self.loras[-1].get_stacked_multiply(module_A) if module_A not in self.A_buffer: self.A_buffer[module_A] = [ @@ -162,7 +216,15 @@ class LoRAManager: for i in range(num_layer) ] # init B tensor, column_major=True - _, hidden_dim_B = self.base_model.get_hidden_dim(module_B) + if hasattr(self.base_model, "get_hidden_dim"): + _, hidden_dim_B = self.base_model.get_hidden_dim(module_B) + else: + logger.warning( + f"WARNING: get_hidden_dim() is not defined, " + f"which is used to get the hidden dim for different lora modules. " + f"Use the default one, but please check if it is correct for your model." + ) + _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config) c = self.loras[-1].get_stacked_multiply(module_B) if module_B not in self.B_buffer: self.B_buffer[module_B] = [ diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index ce7eed969..431250260 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -319,6 +319,7 @@ class LlamaForCausalLM(nn.Module): ) def get_hidden_dim(self, module_name): + # return input_dim, output_dim if module_name in ["q_proj", "o_proj", "qkv_proj"]: return self.config.hidden_size, self.config.hidden_size elif module_name in ["kv_proj"]: diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 7be410ccb..fadc6dd50 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -51,6 +51,7 @@ CI_MODELS = [ # All other models ALL_OTHER_MODELS = [ ModelCase("Qwen/Qwen2-1.5B"), + ModelCase("Qwen/Qwen2.5-14B-Instruct"), ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"), ] diff --git a/test/srt/models/test_lora.py b/test/srt/models/test_lora.py index 85a963893..41ea8fc15 100644 --- a/test/srt/models/test_lora.py +++ b/test/srt/models/test_lora.py @@ -26,6 +26,7 @@ LORA_SETS = [ # "loras": 
["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"], # }, {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, + # {"base": "Qwen/Qwen2.5-14B-Instruct", "loras": ["mssongit/Qwen2.5-14B-SFT-LoRA"]}, # {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]}, # { # "base": "mistralai/Mistral-7B-Instruct-v0.3", @@ -170,7 +171,7 @@ class TestLoRA(unittest.TestCase): print(f"{srt_no_lora_outputs.output_strs=}") for i in range(len(prompts)): assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], ( - str_outputs.output_strs[i].strip(" "), + srt_outputs.output_strs[i].strip(" "), hf_outputs.output_strs[i], ) # assert ( @@ -264,7 +265,7 @@ class TestLoRA(unittest.TestCase): for i in range(len(prompts)): assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], ( - str_outputs.output_strs[i].strip(" "), + srt_outputs.output_strs[i].strip(" "), hf_outputs.output_strs[i], ) assert (