Sync from v0.13
This commit is contained in:
99
tests/lora/test_peft_helper.py
Normal file
99
tests/lora/test_peft_helper.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import math
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.lora.peft_helper import PEFTHelper
|
||||
|
||||
# Each entry is (test_name, config_change, expected_error):
#   test_name      - used as the name of the per-case adapter directory,
#   config_change  - keys merged into the adapter's adapter_config.json,
#   expected_error - text matched against the raised ValueError message.
ERROR_CASES = [
    (
        "test_rank",
        {"r": 1024},
        "is greater than max_lora_rank",
    ),
    ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
    (
        "test_modules_to_save",
        {"modules_to_save": ["lm_head"]},
        "only supports modules_to_save being None",
    ),
]
|
||||
|
||||
|
||||
def test_peft_helper_pass(llama32_lora_files, tmp_path):
    """Check that a valid adapter loads, validates, and exposes its settings.

    The first half loads the adapter found at ``llama32_lora_files`` and
    asserts on the parsed rank, alpha, target modules and position-embedding
    limit. The second half copies the adapter into ``tmp_path``, turns on
    ``use_rslora`` in its ``adapter_config.json``, reloads it, and checks the
    resulting scaling factor against ``lora_alpha / sqrt(r)``.
    """
    peft_helper = PEFTHelper.from_local_dir(
        llama32_lora_files, max_position_embeddings=4096
    )
    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
    peft_helper.validate_legal(lora_config)
    assert peft_helper.r == 8
    assert peft_helper.lora_alpha == 32
    # Sort so the comparison does not depend on the order in the config file.
    target_modules = sorted(peft_helper.target_modules)

    assert target_modules == [
        "down_proj",
        "embed_tokens",
        "gate_proj",
        "k_proj",
        "lm_head",
        "o_proj",
        "q_proj",
        "up_proj",
        "v_proj",
    ]
    assert peft_helper.vllm_max_position_embeddings == 4096

    # test RSLoRA
    rslora_config = {"use_rslora": True}
    test_dir = tmp_path / "test_rslora"
    # Work on a copy so the shared fixture files are never modified.
    shutil.copytree(llama32_lora_files, test_dir)

    # Load and modify configuration
    config_path = test_dir / "adapter_config.json"
    with open(config_path, encoding="utf-8") as f:
        adapter_config = json.load(f)
    # Apply configuration changes
    adapter_config.update(rslora_config)

    # Save modified configuration
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f)

    peft_helper = PEFTHelper.from_local_dir(test_dir, max_position_embeddings=4096)
    peft_helper.validate_legal(lora_config)
    # Expected scaling with use_rslora enabled: alpha / sqrt(rank).
    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
def test_peft_helper_error(
    llama32_lora_files,
    tmp_path,
    test_name: str,
    config_change: dict,
    expected_error: str,
):
    """Check that an unsupported adapter configuration raises ``ValueError``.

    For each case the adapter at ``llama32_lora_files`` is copied into
    ``tmp_path / test_name``, ``config_change`` is merged into the copy's
    ``adapter_config.json``, and loading plus validating the adapter is
    expected to raise a ``ValueError`` whose message matches
    ``expected_error``.
    """
    test_dir = tmp_path / test_name
    # Work on a copy so the shared fixture files are never modified.
    shutil.copytree(llama32_lora_files, test_dir)

    # Load and modify configuration
    config_path = test_dir / "adapter_config.json"
    with open(config_path, encoding="utf-8") as f:
        adapter_config = json.load(f)
    # Apply configuration changes
    adapter_config.update(config_change)

    # Save modified configuration
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f)
    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
    # Test loading the adapter; the error may come from either the load
    # or the validation step, so both stay inside the raises block.
    with pytest.raises(ValueError, match=expected_error):
        PEFTHelper.from_local_dir(
            test_dir, max_position_embeddings=4096
        ).validate_legal(lora_config)
|
||||
Reference in New Issue
Block a user