Sync from v0.13
@@ -1,9 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import pytest

-from vllm.lora.models import LoRAModel
+from vllm.lora.lora_model import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.utils import WeightsMapper
+
-lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
+lora_lst = ["baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"]
+BAICHUAN_LORA_MODULES = [
+    "W_pack",
+    "o_proj",
+    "gate_up_proj",
+    "down_proj",
+]
+
+
 @pytest.mark.parametrize("lora_name", lora_lst)
@@ -11,48 +22,109 @@ def test_load_checkpoints(
     lora_name,
     baichuan_lora_files,
     baichuan_zero_lora_files,
+    baichuan_regex_lora_files,
     chatglm3_lora_files,
 ):
-    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
-    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
-    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
-    expected_lora_modules = []
-    for module in supported_lora_modules:
+
+    expected_lora_lst: list[str] = []
+    for module in BAICHUAN_LORA_MODULES:
         if module in packed_modules_mapping:
-            expected_lora_modules.extend(packed_modules_mapping[module])
+            expected_lora_lst.extend(packed_modules_mapping[module])
         else:
-            expected_lora_modules.append(module)
+            expected_lora_lst.append(module)
+    expected_lora_modules = set(expected_lora_lst)
     if lora_name == "baichuan7B":
+        peft_helper = PEFTHelper.from_local_dir(
+            baichuan_lora_files, max_position_embeddings=4096
+        )
         # For the baichuan7B model, load its LoRA,
         # and the test should pass.
         LoRAModel.from_local_checkpoint(
             baichuan_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
-            embedding_modules=embedding_modules,
-            embedding_padding_modules=embed_padding_modules)
+            model_vocab_size=64000,
+        )
     elif lora_name == "baichuan7B-zero":
-        #Test that the target_modules contain prefix
+        # Test that the target_modules contain a prefix
         # such as "model.layers.0.self_attn.W_pack", and
         # the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(
+            baichuan_zero_lora_files, max_position_embeddings=4096
+        )
         LoRAModel.from_local_checkpoint(
             baichuan_zero_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
-            embedding_modules=embedding_modules,
-            embedding_padding_modules=embed_padding_modules)
+            model_vocab_size=64000,
+        )
+    elif lora_name == "baichuan7B-zero-regex":
+        # Test that `target_modules` given as a regular expression,
+        # such as `model\\..*(W_pack|o_proj)`, also loads, and the
+        # test should pass.
+        peft_helper = PEFTHelper.from_local_dir(
+            baichuan_regex_lora_files, max_position_embeddings=4096
+        )
+        LoRAModel.from_local_checkpoint(
+            baichuan_regex_lora_files,
+            expected_lora_modules,
+            peft_helper=peft_helper,
+            lora_model_id=1,
+            device="cpu",
+            model_vocab_size=64000,
+        )
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
         expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        peft_helper = PEFTHelper.from_local_dir(
+            chatglm3_lora_files, max_position_embeddings=4096
+        )
         with pytest.raises(ValueError, match=expected_error):
             LoRAModel.from_local_checkpoint(
                 chatglm3_lora_files,
                 expected_lora_modules,
+                peft_helper=peft_helper,
                 lora_model_id=1,
                 device="cpu",
-                embedding_modules=embedding_modules,
-                embedding_padding_modules=embed_padding_modules)
+                model_vocab_size=64000,
+            )
+
+
+def test_lora_weights_mapping(baichuan_lora_files):
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+
+    expected_lora_lst: list[str] = []
+    for module in BAICHUAN_LORA_MODULES:
+        if module in packed_modules_mapping:
+            expected_lora_lst.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_lst.append(module)
+    expected_lora_modules = set(expected_lora_lst)
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.": "language_model.model.",
+        },
+        orig_to_new_substr={
+            ".layers.": ".baichuan_layers.",
+        },
+    )
+    peft_helper = PEFTHelper.from_local_dir(
+        baichuan_lora_files, max_position_embeddings=4096
+    )
+    lora_model = LoRAModel.from_local_checkpoint(
+        baichuan_lora_files,
+        expected_lora_modules,
+        peft_helper=peft_helper,
+        lora_model_id=1,
+        device="cpu",
+        model_vocab_size=64000,
+        weights_mapper=hf_to_vllm_mapper,
+    )
+    for name in lora_model.loras:
+        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
+        assert ".baichuan_layers." in name
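For callers updating against this sync, here is a minimal standalone sketch of the new loading flow, distilled from the test above. It assumes a vLLM build that already contains the post-sync layout (`vllm.lora.lora_model`, the `PEFTHelper`-based `from_local_checkpoint` signature); the checkpoint path `LORA_DIR` is a hypothetical placeholder, and the module list and `model_vocab_size=64000` simply mirror the Baichuan-7B values used in the test.

from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM

LORA_DIR = "/path/to/baichuan7B-lora"  # hypothetical local checkpoint dir

# Expand packed modules via the model's packed_modules_mapping
# (e.g. "gate_up_proj" unpacks into its constituent projections),
# then deduplicate into a set, exactly as the updated test does.
packed = BaiChuanBaseForCausalLM.packed_modules_mapping
modules: list[str] = []
for module in ["W_pack", "o_proj", "gate_up_proj", "down_proj"]:
    modules.extend(packed.get(module, [module]))
expected_lora_modules = set(modules)

# PEFTHelper reads the adapter's config from the checkpoint directory.
peft_helper = PEFTHelper.from_local_dir(LORA_DIR, max_position_embeddings=4096)

# Post-sync signature: embedding_modules/embedding_padding_modules are gone;
# pass model_vocab_size instead (64000 = Baichuan-7B vocabulary size).
lora_model = LoRAModel.from_local_checkpoint(
    LORA_DIR,
    expected_lora_modules,
    peft_helper=peft_helper,
    lora_model_id=1,
    device="cpu",
    model_vocab_size=64000,
)
print(sorted(lora_model.loras))  # per-module LoRA weights, keyed by module name

If the adapter was trained against Hugging Face module names that differ from vLLM's internal ones, an optional `weights_mapper=WeightsMapper(...)` argument remaps name prefixes and substrings at load time, as `test_lora_weights_mapping` above exercises.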