enginex-mthreads-vllm/tests/lora/test_lora_checkpoints.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper
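# Note: the *_lora_files parameters used by the tests below are pytest
# fixtures; they are presumably provided by this suite's conftest.py and
# resolve to local LoRA checkpoint directories.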

lora_lst = ["baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"]
BAICHUAN_LORA_MODULES = [
    "W_pack",
    "o_proj",
    "gate_up_proj",
    "down_proj",
]
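# Packed modules (e.g. "gate_up_proj", which typically fuses "gate_proj"
# and "up_proj") appear in the model's packed_modules_mapping and are
# expanded to their sub-modules when building the expected LoRA module
# set in the tests below.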


@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
    lora_name,
    baichuan_lora_files,
    baichuan_zero_lora_files,
    baichuan_regex_lora_files,
    chatglm3_lora_files,
):
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    expected_lora_lst: list[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_lst.extend(packed_modules_mapping[module])
        else:
            expected_lora_lst.append(module)
    expected_lora_modules = set(expected_lora_lst)
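    # For illustration: assuming the usual Baichuan mapping, where
    # "gate_up_proj" expands to ["gate_proj", "up_proj"], this set is
    # roughly {"W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"}.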
    if lora_name == "baichuan7B":
        peft_helper = PEFTHelper.from_local_dir(
            baichuan_lora_files, max_position_embeddings=4096
        )
        # For the baichuan7B model, load its own LoRA;
        # the test should pass.
        LoRAModel.from_local_checkpoint(
            baichuan_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            model_vocab_size=64000,
        )
    elif lora_name == "baichuan7B-zero":
        # Test target_modules that carry a full module prefix,
        # such as "model.layers.0.self_attn.W_pack";
        # the test should pass.
        peft_helper = PEFTHelper.from_local_dir(
            baichuan_zero_lora_files, max_position_embeddings=4096
        )
        LoRAModel.from_local_checkpoint(
            baichuan_zero_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            model_vocab_size=64000,
        )
    elif lora_name == "baichuan7B-zero-regex":
        # Test target_modules given as a regular expression,
        # such as `model\\..*(W_pack|o_proj)`;
        # the test should pass.
        peft_helper = PEFTHelper.from_local_dir(
            baichuan_regex_lora_files, max_position_embeddings=4096
        )
        LoRAModel.from_local_checkpoint(
            baichuan_regex_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            model_vocab_size=64000,
        )
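        # For reference, such a checkpoint stores the regex as a plain
        # string in its adapter_config.json (illustrative values only):
        #   {"r": 8, "lora_alpha": 16,
        #    "target_modules": "model\\..*(W_pack|o_proj)"}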
    else:
        # For the baichuan7B model, load chatglm3-6b's LoRA,
        # and the test should raise the following error.
        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
        peft_helper = PEFTHelper.from_local_dir(
            chatglm3_lora_files, max_position_embeddings=4096
        )
        with pytest.raises(ValueError, match=expected_error):
            LoRAModel.from_local_checkpoint(
                chatglm3_lora_files,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=1,
                device="cpu",
                model_vocab_size=64000,
            )
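

# The test below exercises the optional weights_mapper argument: LoRA
# weight names should be rewritten with the same HF -> vLLM renaming
# rules applied to base-model weights. The ".baichuan_layers." substring
# is an artificial, test-only rename.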
def test_lora_weights_mapping(baichuan_lora_files):
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    expected_lora_lst: list[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_lst.extend(packed_modules_mapping[module])
        else:
            expected_lora_lst.append(module)
    expected_lora_modules = set(expected_lora_lst)
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
        },
        orig_to_new_substr={
            ".layers.": ".baichuan_layers.",
        },
    )
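    # Given the mapper above, a weight named e.g.
    # "model.layers.0.self_attn.W_pack" should come out as
    # "language_model.model.baichuan_layers.0.self_attn.W_pack",
    # which is what the assertions at the end of this test check.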
    peft_helper = PEFTHelper.from_local_dir(
        baichuan_lora_files, max_position_embeddings=4096
    )
    lora_model = LoRAModel.from_local_checkpoint(
        baichuan_lora_files,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        model_vocab_size=64000,
        weights_mapper=hf_to_vllm_mapper,
    )
    for name in lora_model.loras:
        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
        assert ".baichuan_layers." in name
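

# Run just this file with, e.g.:
#   pytest tests/lora/test_lora_checkpoints.py -v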