[Fix] Fix AttributeError in Qwen2.5 LoRA: 'Qwen2ForCausalLM' object has no attribute 'get_hidden_dim' (#1536)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Minsang Song
2024-10-03 12:41:15 +09:00
committed by GitHub
parent 4ae0969c0a
commit e6852b0dd2
4 changed files with 73 additions and 8 deletions

View File

@@ -17,6 +17,7 @@ limitations under the License.
# and "Punica: Multi-Tenant LoRA Serving"
import logging
import re
import torch
@@ -26,11 +27,48 @@ from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip, replace_submodule
logger = logging.getLogger(__name__)
# ROCm: flashinfer available later
if not is_hip():
from flashinfer import SegmentGEMMWrapper
def get_module_name(name):
    """Map a HF config module name to the fused module name used by the model class.

    This is the fallback used when the base model does not implement
    ``get_module_name`` itself; see llama.py for a model-level reference
    implementation. Verify the mapping matches your base model.
    """
    # Fallback solution of mapping from config module name to module name in model class.
    # Please check if it aligns with your base model.
    # Please implement the function in the model class if it is not.
    # You can reference this function in llama.py.
    fused_name_by_config_name = {
        "q_proj": "qkv_proj",
        "k_proj": "qkv_proj",
        "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj",
        "up_proj": "gate_up_proj",
    }
    # Names with no fused counterpart (e.g. o_proj, down_proj) pass through unchanged.
    return fused_name_by_config_name.get(name, name)
def get_hidden_dim(module_name, config):
    """Return ``(input_dim, output_dim)`` for a LoRA-targeted module.

    Fallback used when the base model does not implement ``get_hidden_dim``
    itself; see llama.py for a model-level reference implementation. Please
    check that these dimensions align with your base model.

    Args:
        module_name: Fused module name (e.g. ``qkv_proj``, ``gate_up_proj``).
        config: HF model config providing ``hidden_size``,
            ``intermediate_size``, ``num_attention_heads`` and
            ``num_key_value_heads``.

    Raises:
        NotImplementedError: If ``module_name`` is not a known module.
    """
    # Fallback solution of get_hidden_dim for different modules
    # Please check if it aligns with your base model.
    # Please implement the function in the model class if it is not.
    # You can reference this function in llama.py.
    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
        return config.hidden_size, config.hidden_size
    elif module_name in ["kv_proj"]:
        # K/V output dim is scaled down by the GQA group size
        # (num_attention_heads // num_key_value_heads).
        return config.hidden_size, config.hidden_size // (
            config.num_attention_heads // config.num_key_value_heads
        )
    elif module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    elif module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    else:
        # Include the offending name so the failure is actionable.
        raise NotImplementedError(
            f"get_hidden_dim does not support module: {module_name}"
        )
def get_stacked_name(name):
# origin name -> (name for A, name for B)
params_mapping = {
@@ -103,12 +141,20 @@ class LoRAManager:
self.origin_target_modules = set(self.origin_target_modules) | set(
self.configs[name].target_modules
)
self.target_modules = set(
[
if hasattr(self.base_model, "get_module_name"):
self.target_modules = {
self.base_model.get_module_name(module)
for module in self.origin_target_modules
]
)
}
else:
logger.warning(
f"WARNING: get_module_name() is not defined, "
f"which is used to map config module name to model implementation module name."
f"Use the default one, but please check if it is correct for your model."
)
self.target_modules = {
get_module_name(module) for module in self.origin_target_modules
}
self.target_weights = set(
[get_stacked_name(module) for module in self.origin_target_modules]
)
@@ -146,7 +192,15 @@ class LoRAManager:
num_layer = self.base_hf_config.num_hidden_layers
for module_A, module_B in self.target_weights:
# init A tensor, column_major=True
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
if hasattr(self.base_model, "get_hidden_dim"):
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
else:
logger.warning(
f"WARNING: get_hidden_dim() is not defined, "
f"which is used to get the hidden dim for different lora modules"
f"Use the default one, but please check if it is correct for your model."
)
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
@@ -162,7 +216,15 @@ class LoRAManager:
for i in range(num_layer)
]
# init B tensor, column_major=True
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
if hasattr(self.base_model, "get_hidden_dim"):
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
else:
logger.warning(
f"WARNING: get_hidden_dim() is not defined, "
f"which is used to get the hidden dim for different lora modules"
f"Use the default one, but please check if it is correct for your model."
)
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [

View File

@@ -319,6 +319,7 @@ class LlamaForCausalLM(nn.Module):
)
def get_hidden_dim(self, module_name):
# return input_dim, output_dim
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return self.config.hidden_size, self.config.hidden_size
elif module_name in ["kv_proj"]: