[Fix] Fix AttributeError in Qwen2.5 LoRA: 'Qwen2ForCausalLM' object has no attribute 'get_hidden_dim' (#1536)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
# and "Punica: Multi-Tenant LoRA Serving"
|
# and "Punica: Multi-Tenant LoRA Serving"
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -26,11 +27,48 @@ from sglang.srt.lora.lora_config import LoRAConfig
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import is_hip, replace_submodule
|
from sglang.srt.utils import is_hip, replace_submodule
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ROCm: flashinfer available later
|
# ROCm: flashinfer available later
|
||||||
if not is_hip():
|
if not is_hip():
|
||||||
from flashinfer import SegmentGEMMWrapper
|
from flashinfer import SegmentGEMMWrapper
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_name(name):
    """Map a LoRA config module name to the module name used by the model.

    This is the fallback mapping used when the model class does not provide
    its own ``get_module_name``. Please check that it aligns with your base
    model, and implement the function in the model class if it does not
    (llama.py can be used as a reference).

    Names without an entry in the mapping are returned unchanged.
    """
    # Fused module name -> the config-level projections folded into it.
    fused_groups = {
        "qkv_proj": ("q_proj", "k_proj", "v_proj"),
        "gate_up_proj": ("gate_proj", "up_proj"),
    }
    lookup = {
        member: fused
        for fused, members in fused_groups.items()
        for member in members
    }
    return lookup.get(name, name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hidden_dim(module_name, config):
    """Return the pair of hidden dimensions for a LoRA target module.

    This is the fallback used when the model class does not provide its own
    ``get_hidden_dim``. Please check that it aligns with your base model, and
    implement the function in the model class if it does not (llama.py can be
    used as a reference).

    Args:
        module_name: Module name such as ``"qkv_proj"``, ``"kv_proj"``,
            ``"gate_up_proj"`` or ``"down_proj"``.
        config: Object exposing ``hidden_size``, ``intermediate_size``,
            ``num_attention_heads`` and ``num_key_value_heads``. Only the
            attributes needed for the requested module are read.

    Returns:
        A 2-tuple of ints; presumably ``(input_dim, output_dim)``, matching
        the model-class implementation -- confirm against llama.py.

    Raises:
        NotImplementedError: If ``module_name`` has no fallback rule here.
    """
    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
        return config.hidden_size, config.hidden_size
    if module_name in ["kv_proj"]:
        # The second dimension is hidden_size divided by the ratio of
        # attention heads to key/value heads.
        heads_per_kv_head = config.num_attention_heads // config.num_key_value_heads
        return config.hidden_size, config.hidden_size // heads_per_kv_head
    if module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    if module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    # Name the offending module so the failure is actionable.
    raise NotImplementedError(
        f"get_hidden_dim() has no fallback for module {module_name!r}; "
        "please implement get_hidden_dim() in the model class."
    )
|
||||||
|
|
||||||
|
|
||||||
def get_stacked_name(name):
|
def get_stacked_name(name):
|
||||||
# origin name -> (name for A, name for B)
|
# origin name -> (name for A, name for B)
|
||||||
params_mapping = {
|
params_mapping = {
|
||||||
@@ -103,12 +141,20 @@ class LoRAManager:
|
|||||||
self.origin_target_modules = set(self.origin_target_modules) | set(
|
self.origin_target_modules = set(self.origin_target_modules) | set(
|
||||||
self.configs[name].target_modules
|
self.configs[name].target_modules
|
||||||
)
|
)
|
||||||
self.target_modules = set(
|
if hasattr(self.base_model, "get_module_name"):
|
||||||
[
|
self.target_modules = {
|
||||||
self.base_model.get_module_name(module)
|
self.base_model.get_module_name(module)
|
||||||
for module in self.origin_target_modules
|
for module in self.origin_target_modules
|
||||||
]
|
}
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"WARNING: get_module_name() is not defined, "
|
||||||
|
f"which is used to map config module name to model implementation module name."
|
||||||
|
f"Use the default one, but please check if it is correct for your model."
|
||||||
)
|
)
|
||||||
|
self.target_modules = {
|
||||||
|
get_module_name(module) for module in self.origin_target_modules
|
||||||
|
}
|
||||||
self.target_weights = set(
|
self.target_weights = set(
|
||||||
[get_stacked_name(module) for module in self.origin_target_modules]
|
[get_stacked_name(module) for module in self.origin_target_modules]
|
||||||
)
|
)
|
||||||
@@ -146,7 +192,15 @@ class LoRAManager:
|
|||||||
num_layer = self.base_hf_config.num_hidden_layers
|
num_layer = self.base_hf_config.num_hidden_layers
|
||||||
for module_A, module_B in self.target_weights:
|
for module_A, module_B in self.target_weights:
|
||||||
# init A tensor, column_major=True
|
# init A tensor, column_major=True
|
||||||
|
if hasattr(self.base_model, "get_hidden_dim"):
|
||||||
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
|
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"WARNING: get_hidden_dim() is not defined, "
|
||||||
|
f"which is used to get the hidden dim for different lora modules"
|
||||||
|
f"Use the default one, but please check if it is correct for your model."
|
||||||
|
)
|
||||||
|
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
|
||||||
c = self.loras[-1].get_stacked_multiply(module_A)
|
c = self.loras[-1].get_stacked_multiply(module_A)
|
||||||
if module_A not in self.A_buffer:
|
if module_A not in self.A_buffer:
|
||||||
self.A_buffer[module_A] = [
|
self.A_buffer[module_A] = [
|
||||||
@@ -162,7 +216,15 @@ class LoRAManager:
|
|||||||
for i in range(num_layer)
|
for i in range(num_layer)
|
||||||
]
|
]
|
||||||
# init B tensor, column_major=True
|
# init B tensor, column_major=True
|
||||||
|
if hasattr(self.base_model, "get_hidden_dim"):
|
||||||
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
|
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"WARNING: get_hidden_dim() is not defined, "
|
||||||
|
f"which is used to get the hidden dim for different lora modules"
|
||||||
|
f"Use the default one, but please check if it is correct for your model."
|
||||||
|
)
|
||||||
|
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
|
||||||
c = self.loras[-1].get_stacked_multiply(module_B)
|
c = self.loras[-1].get_stacked_multiply(module_B)
|
||||||
if module_B not in self.B_buffer:
|
if module_B not in self.B_buffer:
|
||||||
self.B_buffer[module_B] = [
|
self.B_buffer[module_B] = [
|
||||||
|
|||||||
@@ -319,6 +319,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
|
# return input_dim, output_dim
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
||||||
return self.config.hidden_size, self.config.hidden_size
|
return self.config.hidden_size, self.config.hidden_size
|
||||||
elif module_name in ["kv_proj"]:
|
elif module_name in ["kv_proj"]:
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ CI_MODELS = [
|
|||||||
# All other models
|
# All other models
|
||||||
ALL_OTHER_MODELS = [
|
ALL_OTHER_MODELS = [
|
||||||
ModelCase("Qwen/Qwen2-1.5B"),
|
ModelCase("Qwen/Qwen2-1.5B"),
|
||||||
|
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||||
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"),
|
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ LORA_SETS = [
|
|||||||
# "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"],
|
# "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"],
|
||||||
# },
|
# },
|
||||||
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
|
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
|
||||||
|
# {"base": "Qwen/Qwen2.5-14B-Instruct", "loras": ["mssongit/Qwen2.5-14B-SFT-LoRA"]},
|
||||||
# {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
|
# {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
|
||||||
# {
|
# {
|
||||||
# "base": "mistralai/Mistral-7B-Instruct-v0.3",
|
# "base": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||||
@@ -170,7 +171,7 @@ class TestLoRA(unittest.TestCase):
|
|||||||
print(f"{srt_no_lora_outputs.output_strs=}")
|
print(f"{srt_no_lora_outputs.output_strs=}")
|
||||||
for i in range(len(prompts)):
|
for i in range(len(prompts)):
|
||||||
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
|
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
|
||||||
str_outputs.output_strs[i].strip(" "),
|
srt_outputs.output_strs[i].strip(" "),
|
||||||
hf_outputs.output_strs[i],
|
hf_outputs.output_strs[i],
|
||||||
)
|
)
|
||||||
# assert (
|
# assert (
|
||||||
@@ -264,7 +265,7 @@ class TestLoRA(unittest.TestCase):
|
|||||||
|
|
||||||
for i in range(len(prompts)):
|
for i in range(len(prompts)):
|
||||||
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
|
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
|
||||||
str_outputs.output_strs[i].strip(" "),
|
srt_outputs.output_strs[i].strip(" "),
|
||||||
hf_outputs.output_strs[i],
|
hf_outputs.output_strs[i],
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
Reference in New Issue
Block a user