add qwen3

Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

@@ -0,0 +1,290 @@
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
import vllm
from vllm.config import LoRAConfig
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
class ContextIDInfo(TypedDict):
lora_id: int
context_length: str
class ContextInfo(TypedDict):
lora: str
context_length: str
LONG_LORA_INFOS: List[ContextIDInfo] = [{
"lora_id": 1,
"context_length": "16k",
}, {
"lora_id": 2,
"context_length": "16k",
}, {
"lora_id": 3,
"context_length": "32k",
}]
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""
return not request.node.get_closest_marker("skip_global_cleanup")
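# A minimal usage sketch (assumed example, not part of the original file):
# a test opts out of the global cleanup below by applying the marker this
# fixture inspects, e.g.
#
#     @pytest.mark.skip_global_cleanup
#     def test_cpu_only_helper():
#         ...  # cleanup_dist_env_and_memory() is skipped after this test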
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
if should_do_global_cleanup_after_test:
cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture
def dist_init():
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture
def dist_init_torch_only():
if torch.distributed.is_initialized():
return
temp_file = tempfile.mkstemp()[1]
torch.distributed.init_process_group(
backend="nccl",
world_size=1,
rank=0,
init_method=f"file://{temp_file}",
)
@pytest.fixture
def dummy_model() -> nn.Module:
model = nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
(
"layer1",
nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(100, 10)),
("dense2", RowParallelLinear(10, 50)),
])),
),
("act2", nn.ReLU()),
("output", ColumnParallelLinear(50, 10)),
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
]))
model.config = MagicMock()
return model
@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
model = nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
(
"layer1",
nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(100, 10)),
("dense2", RowParallelLinear(10, 50)),
])),
),
("act2", nn.ReLU()),
("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
]))
model.config = MagicMock()
return model
'''
=============================
Modified by vllm_mlu
=============================
@brief: use the linked models in CI
'''
def get_repo_path(repo_id):
"""Do not download the repo when the path exists."""
import os
if os.path.exists(repo_id):
return repo_id
return snapshot_download(repo_id=repo_id)
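# Usage sketch (paths are hypothetical): a pre-fetched local checkout
# short-circuits the download; anything else falls through to the Hub.
#
#     get_repo_path("/data/ci/llama-2-7b-sql-lora-test")  # returned as-is
#     get_repo_path("yard1/llama-2-7b-sql-lora-test")     # snapshot_download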
@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading.
return get_repo_path("yard1/llama-2-7b-sql-lora-test")
@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
return get_repo_path(repo_id=sql_lora_huggingface_id)
@pytest.fixture(scope="session")
def lora_bias_files():
return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
@pytest.fixture(scope="session")
def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test
# https://github.com/vllm-project/vllm/pull/5909/files.
return get_repo_path(repo_id="SangBinCho/mixtral-lora")
@pytest.fixture(scope="session")
def mixtral_lora_files_all_target_modules():
return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
@pytest.fixture(scope="session")
def gemma_lora_files():
return get_repo_path(repo_id="wskwon/gemma-7b-test-lora")
@pytest.fixture(scope="session")
def chatglm3_lora_files():
return get_repo_path(repo_id="jeeejeee/chatglm3-text2sql-spider")
@pytest.fixture(scope="session")
def baichuan_lora_files():
return get_repo_path(repo_id="jeeejeee/baichuan7b-text2sql-spider")
@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
# all the lora_B weights are initialized to zero.
return get_repo_path(repo_id="jeeejeee/baichuan7b-zero-init")
@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
return get_repo_path(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
@pytest.fixture(scope="session")
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
@pytest.fixture(scope="session")
def tinyllama_lora_files():
return get_repo_path(repo_id="jashing/tinyllama-colorist-lora")
@pytest.fixture(scope="session")
def phi2_lora_files():
return get_repo_path(repo_id="isotr0py/phi-2-test-sql-lora")
@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
return get_repo_path(repo_id="SangBinCho/long_context_16k_testing_1")
@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
return get_repo_path(repo_id="SangBinCho/long_context_16k_testing_2")
@pytest.fixture(scope="session")
def long_context_lora_files_32k():
return get_repo_path(repo_id="SangBinCho/long_context_32k_testing")
'''
==================
End of MLU Hijack
==================
'''
@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2,
long_context_lora_files_32k):
cleanup_dist_env_and_memory(shutdown_ray=True)
infos: Dict[int, ContextInfo] = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
if lora_id == 1:
lora = long_context_lora_files_16k_1
elif lora_id == 2:
lora = long_context_lora_files_16k_2
elif lora_id == 3:
lora = long_context_lora_files_32k
else:
raise AssertionError("Unknown lora id")
infos[lora_id] = {
"context_length": lora_checkpoint_info["context_length"],
"lora": lora,
}
return infos
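# The returned mapping has this shape (paths are illustrative):
#
#     {1: {"context_length": "16k", "lora": "/path/to/16k_testing_1"},
#      2: {"context_length": "16k", "lora": "/path/to/16k_testing_2"},
#      3: {"context_length": "32k", "lora": "/path/to/32k_testing"}}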
@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model
def get_model_patched(**kwargs):
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
max_lora_rank=8)
return get_model_old(**kwargs)
with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
yield engine.llm_engine
del engine
cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)

File diff suppressed because one or more lines are too long

@@ -0,0 +1,113 @@
from typing import List
import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
MODEL_PATH = "baichuan-inc/Baichuan-7B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True)
expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age ASC",
]
output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]
@pytest.mark.skip("Requires multiple GPUs")
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
num_gpus_available, fully_sharded):
if num_gpus_available < 4:
    pytest.skip("Not enough GPUs for tensor parallelism 4")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
del llm_tp1
cleanup_dist_env_and_memory()
llm_tp2 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
del llm_tp2
cleanup_dist_env_and_memory()
assert output_tp1 == output_tp2
llm_tp4 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
del llm_tp4
cleanup_dist_env_and_memory()
assert output_tp1 == output_tp4

@@ -0,0 +1,59 @@
from typing import List
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True)
expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]

@@ -0,0 +1,54 @@
from typing import List
import pytest
import vllm
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
MODEL_PATH = "google/gemma-7b"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"Quote: Imagination is",
"Quote: Be yourself;",
"Quote: Painting is poetry that is seen rather than felt,",
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4)
expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
"everyone else is already taken.\nAuthor: Oscar Wilde\n",
"and poetry is painting that is felt rather than seen.\n"
"Author: Leonardo da Vinci\n",
]
output1 = do_sample(llm, gemma_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i].startswith(expected_lora_output[i])
output2 = do_sample(llm, gemma_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i].startswith(expected_lora_output[i])

File diff suppressed because it is too large

@@ -0,0 +1,146 @@
from typing import List
import pytest
import ray
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size)
expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
]
expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output
print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output
print("removing lora")
def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
if num_gpus_available < 4:
pytest.skip("Not enough GPUs for tensor parallelism 4")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
del llm_tp1
cleanup_dist_env_and_memory()
llm_tp2 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2)
output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
del llm_tp2
cleanup_dist_env_and_memory()
assert output_tp1 == output_tp2
llm_tp4 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4)
output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
del llm_tp4
cleanup_dist_env_and_memory()
assert output_tp1 == output_tp4
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and
is more conservative"""
@ray.remote(num_gpus=1)
def get_num_gpu_blocks_lora():
llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
return num_gpu_blocks_lora_warmup
@ray.remote(num_gpus=1)
def get_num_gpu_blocks_no_lora():
llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
num_gpu_blocks_no_lora_warmup = (
llm.llm_engine.cache_config.num_gpu_blocks)
return num_gpu_blocks_no_lora_warmup
num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
num_gpu_blocks_no_lora_warmup = ray.get(
get_num_gpu_blocks_no_lora.remote())
assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, (
"The warmup with lora should be more "
"conservative than without lora, therefore the number of "
"memory blocks for the KV cache should be "
"less when using lora than when not using lora")

@@ -0,0 +1,298 @@
import ast
from typing import List, Optional, Tuple
import numpy as np
import pytest
import vllm
from vllm import SamplingParams
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding)
from .data.long_context_test_data import prompts_and_responses
context_len_to_scaling_factor = {
"16k": 4,
"32k": 8,
}
# We use the same sampling params for all requests
sampling_params = SamplingParams(
temperature=0,
max_tokens=100,
)
def _create_lora_request(lora_id, long_context_infos):
context_len = long_context_infos[lora_id]["context_length"]
scaling_factor = context_len_to_scaling_factor[context_len]
return LoRARequest(
# There are 2 LoRAs for 16K, so we append lora_id to the name to
# distinguish them.
context_len + str(lora_id),
lora_id,
long_context_infos[lora_id]["lora"],
None,
4096 * scaling_factor,
)
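# Worked example (values follow from the code above): for lora_id=1 with a
# "16k" adapter the scaling factor is 4, so the request is named "16k1" and
# its long-LoRA max length is 4096 * 4 = 16384 tokens.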
def evaluate_json_response(model_response, golden_response):
"""Evaluates the model response against the golden response.
Returns a score between 0 and 1, where 1 is a perfect match and 0 is no
match. The score quantifies how well the model is able to extract the
golden JSON from the long context.
"""
try:
model_response = ast.literal_eval(model_response)
except Exception as e:
raise ValueError(
f"Model response is not a valid JSON. Expected {golden_response}, "
f"got {model_response}") from e
# Normally, we would flatten the dictionary and compare the values, but in
# this case, we know that the dictionary is only 2 levels deep
positive_values = 0
total_values = 0
# We look at all the attributes of the person whose biography we are
# extracting and compare them to the golden response.
for person_attribute, person_attribute_value in golden_response.items():
if person_attribute in model_response:
if isinstance(person_attribute_value, dict):
for (sub_attribute,
sub_attribute_value) in person_attribute_value.items():
total_values += 1
if sub_attribute in model_response[
person_attribute] and model_response[
person_attribute][
sub_attribute] == sub_attribute_value:
positive_values += 1
else:
total_values += 1
if model_response[person_attribute] == person_attribute_value:
positive_values += 1
else:
# We count a missing sub-dict as a single missed value.
total_values += 1
# Return a score between 0 and 1
return positive_values / total_values
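# A minimal sanity sketch (hypothetical data, not part of the original
# suite): two of the three leaf values match, so the expected score is 2/3.
def _example_evaluate_json_response():
    golden = {"name": "A", "address": {"city": "B", "zip": "C"}}
    model = "{'name': 'A', 'address': {'city': 'B', 'zip': 'X'}}"
    assert abs(evaluate_json_response(model, golden) - 2 / 3) < 1e-9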
def generate(
llm: vllm.LLM,
inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
):
prompts, sampling_param, lora_request = inputs
outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
return outputs[0].outputs[0].text.strip()
def batched_generate(
llm: vllm.LLM,
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
):
for prompt, sampling_param, lora_req in inputs:
# Add requests to the engine and run the engine
llm._validate_and_add_requests(prompt,
sampling_param,
lora_request=lora_req,
prompt_adapter_request=None)
outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
@pytest.fixture(scope="module")
def lora_llm(long_context_infos):
scaling_factors = [
context_len_to_scaling_factor[info["context_length"]]
for info in long_context_infos.values()
]
llm = vllm.LLM(
"meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
# FIXME enable async output processor
disable_async_output_proc=True,
distributed_executor_backend="mp")
yield llm
del llm
def test_rotary_emb_replaced(dist_init):
"""Verify rotary emb in all the layers are replaced"""
from vllm.engine.arg_utils import EngineArgs
from vllm.worker.model_runner import ModelRunner
engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
long_lora_scaling_factors=(4.0, ),
enable_lora=True)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
vllm_config=engine_config,
is_driver_worker=True,
)
model_runner.load_model()
rotary_emb_count = 0
for module_name, module in model_runner.model.named_modules(
remove_duplicate=False):
if "rotary_emb" in module_name:
if "base_layer" not in module_name:
rotary_emb_count += 1
assert isinstance(module, LinearScalingRotaryEmbeddingWithLora)
else:
assert isinstance(module, LinearScalingRotaryEmbedding)
# Llama 2 has 32 layers.
assert rotary_emb_count == 32
@pytest.mark.skip_global_cleanup
def test_batched_rope_kernel(lora_llm, long_context_infos):
"""We test the batched kernel by comparing the results of batched an
non-batched generation.
"""
# Create non batched results first to compare against batched results
non_batched_results: List[str] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
sampling_params,
_create_lora_request(lora_id, long_context_infos))
lora_output = generate(lora_llm, lora_prompt)
non_batched_results.append(lora_output)
# Create batched results
# Each element of the batch must be
# (prompt, prompt_sampling_params, prompt_lora_request)
batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
batched_results = batched_generate(lora_llm, batched_prompts)
# Results should be the same
for non_batched, batched in zip(non_batched_results, batched_results):
assert non_batched == batched, (
"Non batched and batched results should be the "
f"same:\n{batched}\n{non_batched}")
@pytest.mark.skip_global_cleanup
def test_self_consistency(lora_llm, long_context_infos):
"""We test consistency of the batched kernel by permuting batched
inputs and comparing the results to the non-permuted batched results.
"""
num_loras = len(long_context_infos)
# Create results in order of long_context_infos
batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
batched_results = batched_generate(lora_llm, batched_prompts)
permutation = np.random.default_rng(seed=42).permutation(num_loras)
# Create results in random order of permutation
batched_prompts = []
for i in permutation:
lora_id, info = list(long_context_infos.items())[i]
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
permutated_batched_results = batched_generate(lora_llm, batched_prompts)
# Results should be the same
for i in range(num_loras):
assert batched_results[i] == permutated_batched_results[
permutation[i]], (
f"Results should be the same:\n{batched_results[i]}"
f"\n{permutated_batched_results[permutation[i]]}")
@pytest.mark.skip_global_cleanup
def test_quality(lora_llm, long_context_infos):
"""We test the quality of the answers given by the LoRA model by
comparing the generated text to the merged model's outputs.
This is effectively a mini-benchmark over four prompts.
If this test fails, this indicates that the quality of the LoRA model
is suboptimal compared to the merged model. For example, if the model
does not output valid dictionaries, this test will fail.
If needed for testing, the merged versions of the models are available
as part of the `conftest`.
The test is expected to run for about 1 minute on a p4de.24xlarge
instance.
"""
scores: List[float] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
for prompt_and_response in prompts_and_responses[context_len]:
lora_prompt = (prompt_and_response["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
response = generate(lora_llm, lora_prompt)
golden_answer = prompt_and_response["golden_answer"]
score = evaluate_json_response(response, golden_answer)
scores.append(score)
assert score > 0.3, ("Quality of the answer is not good enough. "
f"Expected {golden_answer}, got {response}")
assert np.mean(scores) > 0.5
@pytest.mark.skip_global_cleanup
def test_max_len(lora_llm, long_context_infos):
"""Test that we raise an ValueError when the input of a given LoRA
model exceeds the maximum length."""
# Since each LoRA model has a different maximum length, we need to
# test each one separately
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
lora_request = _create_lora_request(lora_id, long_context_infos)
# Good prompt should be fine
good_prompt = prompts_and_responses[context_len][0]["prompt"]
generate(lora_llm, (good_prompt, sampling_params, lora_request))
# Bad prompt should raise an error
bad_prompt = good_prompt * 2
with pytest.raises(ValueError):
generate(lora_llm, (bad_prompt, sampling_params, lora_request))
# Also test batched
batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id_with_bad_inputs in long_context_infos:
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"] *
(2 if lora_id == lora_id_with_bad_inputs else 1),
sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
# A good prompt was doubled into a bad prompt inside the batched
# prompts above, so generation should fail
with pytest.raises(ValueError):
batched_generate(lora_llm, batched_prompts)

@@ -0,0 +1,52 @@
from typing import List
import pytest
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "ibm-granite/granite-3b-code-base"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
generated_texts: List[str] = []
for output in outputs:
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
return generated_texts
@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_lora_rank=8,
max_loras=1,
enable_lora_bias=lora_bias,
tensor_parallel_size=1,
fully_sharded_loras=fully_sharded)
print("lora adapter created")
output1 = do_sample(llm, lora_bias_files, lora_id=0)
print("lora")
output2 = do_sample(llm, lora_bias_files, lora_id=1)
if lora_bias:
assert output1 != output2
else:
assert output1 == output2

@@ -0,0 +1,73 @@
from typing import List
import pytest
from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]
@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
lora_name,
baichuan_lora_files,
baichuan_zero_lora_files,
baichuan_regex_lora_files,
chatglm3_lora_files,
):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
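# Illustrative expansion (hypothetical mapping): with
# packed_modules_mapping == {"gate_up_proj": ["gate_proj", "up_proj"]},
# a supported "gate_up_proj" module contributes its two sub-modules to
# expected_lora_modules, while unpacked modules pass through unchanged.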
if lora_name == "baichuan7B":
# For the baichuan7B model, load its own LoRA;
# the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero":
# Test that target_modules may contain fully prefixed names
# such as "model.layers.0.self_atten.W_pack";
# the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_zero_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero-regex":
# Test that `target_modules` may be given as regular expressions,
# such as `model\\..*(W_pack|o_proj)`; the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_regex_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
else:
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501
with pytest.raises(ValueError, match=expected_error):
LoRAModel.from_local_checkpoint(
chatglm3_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)

@@ -0,0 +1,39 @@
from typing import List
import pytest
from vllm.lora.models import LoRAModel
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM
# Provide an absolute path and a huggingface lora id
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_name = request.getfixturevalue(lora_fixture_name)
supported_lora_modules = LlamaForCausalLM.supported_lora_modules
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
lora_path = get_adapter_absolute_path(lora_name)
# LoRA loading should work for either an absolute path or a huggingface id.
lora_model = LoRAModel.from_local_checkpoint(
lora_path,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
# Assertions to ensure the model is loaded correctly
assert lora_model is not None, "LoRAModel is not loaded correctly"

@@ -0,0 +1,637 @@
import os
from typing import Dict, List
import pytest
import torch
from safetensors.torch import load_file
from torch import nn
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
EMBEDDING_PADDING_MODULES = ["lm_head"]
'''
=============================
Modified by vllm_mlu
=============================
@brief: need to use the mlu device type for device checks
'''
CUDA_DEVICES = [f"mlu:{i}" for i in range(1)]
'''
==================
End of MLU Hijack
==================
'''
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))
lora_model = LoRAModel.from_lora_tensors(
1,
8,
16,
tensors,
device,
embeddings=new_embeddings,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES)
for module_name, lora in lora_model.loras.items():
assert lora.module_name == module_name
assert lora.rank == 8
assert lora.lora_alpha == 16
assert lora.lora_a is not None
assert lora.lora_b is not None
assert lora.lora_a.device == torch.device(device)
assert lora.lora_b.device == torch.device(device)
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
assert lora.lora_a.shape[1] == 8
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name), None)
if embeddings_module:
assert torch.equal(
lora.embeddings_tensor,
new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
device=lora.embeddings_tensor.device))
else:
assert lora.embeddings_tensor is None
def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
device: torch.device) -> LoRAModel:
loras: Dict[str, LoRALayerWeights] = {}
for name in sub_modules:
w = model.get_submodule(name).weight
loras[name] = LoRALayerWeights(
name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0]], device=device),
)
return LoRAModel(lora_id, 8, loras)
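# Shape sketch (follows from the tensors above): for a submodule whose
# weight has shape [out_features, in_features], lora_a is
# [in_features, 8] and lora_b is [8, out_features], matching the rank-8
# checks in test_from_lora_tensors.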
def create_packed_lora(
lora_id: int,
model: nn.Module,
module_name,
replaced_module_names,
device: torch.device,
empty_replaced_module_name=None,
) -> LoRAModel:
w = model.get_submodule(module_name).weight
loras: Dict[str, LoRALayerWeights] = {}
for replaced_module_name in replaced_module_names:
if replaced_module_name == empty_replaced_module_name:
continue
loras[replaced_module_name] = LoRALayerWeights(
replaced_module_name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0] // len(replaced_module_names)],
device=device),
)
return LoRAModel(lora_id, 8, loras)
def test_replace_submodules(dist_init, dummy_model):
model = dummy_model
model.supported_lora_modules = ["dense1", "layer1.dense2"]
model.packed_modules_mapping = {}
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device("cuda"))
model = manager.model
assert isinstance(model.get_submodule("dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
assert isinstance(model.get_submodule("layer1.dense2"),
RowParallelLinearWithLoRA)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
model_lora2 = create_lora(2,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora3 = create_lora(3,
model, ["dense1", "dense2", "lm_head"],
device=device)
manager = LoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_adapter(model_lora1)
assert not manager.activate_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_adapter(model_lora2)
assert not manager.activate_adapter(2)
assert manager.add_adapter(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
with pytest.raises(ValueError):
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.remove_adapter(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_adapter(model_lora2.id)
assert manager.remove_adapter(model_lora1.id)
assert not manager.remove_adapter(model_lora1.id)
assert manager.add_adapter(model_lora1)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] is None
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] is None
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
assert manager.device == device
assert manager.punica_wrapper.device == device
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
model_lora2 = create_lora(2,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora3 = create_lora(3,
model, ["dense1", "dense2", "lm_head"],
device=device)
manager = LRUCacheLoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_adapter(model_lora1)
assert not manager.activate_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_adapter(model_lora2)
assert not manager.activate_adapter(2)
assert manager.add_adapter(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
assert manager.remove_adapter(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_adapter(model_lora2.id)
assert manager.remove_adapter(model_lora1.id)
assert not manager.remove_adapter(model_lora1.id)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.add_adapter(model_lora2)
assert manager.deactivate_adapter(3)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.deactivate_adapter(2)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.pin_adapter(3)
assert manager.pin_adapter(1)
with pytest.raises(RuntimeError):
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
with pytest.raises(RuntimeError):
assert manager.activate_adapter(2)
assert manager.deactivate_adapter(3)
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.remove_adapter(3)
with pytest.raises(ValueError):
assert manager.pin_adapter(3)
assert manager.punica_wrapper.device == device
assert manager.device == device
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality; everything else is
# tested in test_lora_model_manager
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
model_lora2 = create_lora(2,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora3 = create_lora(3,
model, ["dense1", "dense2", "lm_head"],
device=device)
model_lora4 = create_lora(4,
model, ["dense1", "dense2", "lm_head"],
device=device)
manager = LRUCacheLoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
# Add up to capacity
assert manager.add_adapter(model_lora1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(1)
assert manager.activate_adapter(2)
assert set(manager.list_adapters()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
# Add over capacity
assert manager.add_adapter(model_lora3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(3)
assert manager.activate_adapter(4)
assert set(manager.list_adapters()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
# Add 3 again to move it to the top and then add 2;
# adding 3 should return False since it is already present
assert not manager.add_adapter(model_lora3)
assert not manager.activate_adapter(3)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert set(manager.list_adapters()) == {3, 2}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
# Remove manually
assert manager.remove_adapter(3)
assert not manager.remove_adapter(3)
assert set(manager.list_adapters()) == {2}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 2
assert manager.add_adapter(model_lora3)
assert manager.activate_adapter(3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(4)
assert set(manager.list_adapters()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == set()
assert all(x is None for x in manager.lora_index_to_id)
assert not manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == set()
assert all(x is None for x in manager.lora_index_to_id)
# pinning
assert manager.add_adapter(model_lora3)
assert manager.activate_adapter(3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(4)
assert set(manager.list_adapters()) == {3, 4}
with pytest.raises(ValueError):
assert manager.pin_adapter(1)
assert manager.pin_adapter(3)
# Remove manually
assert manager.remove_adapter(3)
assert not manager.remove_adapter(3)
assert set(manager.list_adapters()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4
assert manager.add_adapter(model_lora1)
assert manager.pin_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert set(manager.list_adapters()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] is None
with pytest.raises(RuntimeError):
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1}
assert manager.punica_wrapper.device == device
assert manager.device == device
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
], mapping)
assert worker_adapter_manager.device == device
assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
device)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
], mapping)
assert worker_adapter_manager.device == device
assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
device)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model_lora = create_packed_lora(
1,
model,
module_name="gate_up_proj",
replaced_module_names=["gate_proj", "up_proj"],
device=device)
model_lora1 = create_packed_lora(
2,
model,
module_name="gate_up_proj",
replaced_module_names=["gate_proj", "up_proj"],
device=device,
empty_replaced_module_name="gate_proj",
)
manager = LoRAModelManager(model,
2,
2,
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
device=device)
model = manager.model
assert isinstance(model.get_submodule("gate_up_proj"),
MergedColumnParallelLinearWithLoRA)
assert manager.add_adapter(model_lora)
assert manager.add_adapter(model_lora1)
packed_lora = model_lora.get_lora("gate_up_proj")
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
torch.testing.assert_close(packed_lora.lora_a[0],
model_lora.get_lora("gate_proj").lora_a)
torch.testing.assert_close(packed_lora.lora_b[0],
model_lora.get_lora("gate_proj").lora_b)
torch.testing.assert_close(packed_lora.lora_a[1],
model_lora.get_lora("up_proj").lora_a)
torch.testing.assert_close(packed_lora.lora_b[1],
model_lora.get_lora("up_proj").lora_b)
packed_lora1 = model_lora1.get_lora("gate_up_proj")
assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
assert packed_lora1.lora_a[0] is None
assert packed_lora1.lora_b[0] is None
torch.testing.assert_close(packed_lora1.lora_a[1],
model_lora1.get_lora("up_proj").lora_a)
torch.testing.assert_close(packed_lora1.lora_b[1],
model_lora1.get_lora("up_proj").lora_b)


@@ -0,0 +1,77 @@
from typing import List
import pytest
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")
IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
# After fine-tuning with LoRA, all generated content should begin with `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)
inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]
outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_num_seqs=2,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
gpu_memory_utilization=0.97 # This model is pretty big for CI gpus
)
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
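    # With max_tokens=5 only a short prefix is generated, so assert that the
    # expected full sentence starts with the generated text.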
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])


@@ -0,0 +1,95 @@
from typing import List
import pytest
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")
IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
# After fine-tuning with LoRA, all generated content should begin with `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)
inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]
outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])


@@ -0,0 +1,100 @@
from typing import List
import pytest
import torch
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
prompts: List[str]) -> List[str]:
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora(mixtral_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", # noqa: E501
]
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
)
expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
"give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501
]
assert do_sample(llm, mixtral_lora_files, lora_id=1,
prompts=prompts) == expected_lora_output
assert do_sample(llm, mixtral_lora_files, lora_id=2,
prompts=prompts) == expected_lora_output
@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
tp_size):
"""This LoRA model has all supported Mixtral target modules"""
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
prompts = [
"Question: What happens if you touch the eyes of a blind man?\nChoices:\nB: The man's sight will be restored.\nA: Nothing happens if you touch the eyes of a blind man.\nC: The man will get his eyesight back.\nD: The man will no longer be blind.\nAnswer:", # noqa: E501
"Question: Which will increase the energy of motion of water molecules?\nChoices:\nA: add heat\nB: add ice\nC: remove some of the water\nD: place the water in the freezer\nAnswer:", # noqa: E501
"Since Craig threw aluminum cans in the trash and Benjamin recycled, _ was environmentally irresponsible.\nChoices:\n1: Craig\n2: Benjamin\nAnswer:", # noqa: E501
]
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
max_lora_rank=32,
)
expected_lora_output = [
"A: Nothing happens if you touch the eyes of a blind man.",
"A: add heat",
"1: Craig",
]
assert do_sample(llm,
mixtral_lora_files_all_target_modules,
lora_id=1,
prompts=prompts) == expected_lora_output
assert do_sample(llm,
mixtral_lora_files_all_target_modules,
lora_id=2,
prompts=prompts) == expected_lora_output


@@ -0,0 +1,69 @@
from typing import List
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "microsoft/phi-2"
PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(
sql_prompt=
"Which catalog publisher has published the most catalogs?",
context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"),
PROMPT_TEMPLATE.format(
sql_prompt=
"Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501
context=
"CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501
),
PROMPT_TEMPLATE.format(
sql_prompt=
"How many marine species are found in the Southern Ocean?", # noqa: E501
context=
"CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501
),
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=64,
stop="### End")
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_phi2_lora(phi2_lora_files):
    # We enable enforce_eager=True here to reduce VRAM usage in the lora-test
    # CI; otherwise the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
enforce_eager=True)
expected_lora_output = [
"SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501
"SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501
"SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501
]
output1 = do_sample(llm, phi2_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i].startswith(expected_lora_output[i])
output2 = do_sample(llm, phi2_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i].startswith(expected_lora_output[i])


@@ -0,0 +1,395 @@
"""
This script is mainly used to test various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
'''
=============================
Modify by vllm_mlu
=============================
@brief: use mlu sgmv functions
'''
if current_platform.is_mlu():
from vllm_mlu.lora.ops.sgmv_expand import sgmv_expand_mlu as sgmv_expand
from vllm_mlu.lora.ops.sgmv_expand_slice import sgmv_expand_slice_mlu as sgmv_expand_slice
from vllm_mlu.lora.ops.sgmv_shrink import sgmv_shrink_mlu as sgmv_shrink
'''
==================
End of MLU Hijack
==================
'''
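# With the aliases above in place, the sgmv tests below exercise the MLU
# kernels transparently when running on MLU hardware; on any other platform
# the original Triton ops are used unchanged.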
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
HIDDEN_SIZES = [
128,
256,
512,
896,
1024,
1152,
1216,
1280,
1536,
1664,
2048,
2240,
2304,
2368,
2432,
2560,
2752,
3072,
3328,
3456,
3584,
3712,
4096,
4480,
4608,
4736,
4864,
5120,
5504,
5632,
5888,
6144,
6400,
6848,
6912,
7168,
7424,
8192,
8960,
9216,
9472,
10240,
11008,
11264,
13824,
14336,
14784,
14848,
15360,
18944,
22016,
22528,
24576,
27392,
27648,
29568,
29696,
32000,
32256,
32512,
32768,
33024,
36864,
43264,
49152,
49408,
60544,
60672,
64000,
64256,
102400,
102656,
128000,
128256,
]
# The tensor-parallel (TP) sizes.
divisibility = [1, 2, 8, 16, 64]
all_hidden_size = []
for div in divisibility:
for hidden_size in HIDDEN_SIZES:
all_hidden_size.append(hidden_size // div)
HIDDEN_SIZES = list(set(all_hidden_size))
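# Each base hidden size was divided by every TP degree above to cover the
# per-shard sizes seen under tensor parallelism (e.g. 4096 // 64 == 64);
# set() removes the duplicates this produces.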
BATCHES = [4]
NUM_LORA = [4]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [32]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 128
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
if op_type == "shrink":
sgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
sgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 1
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
if op_type == "shrink":
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,
lora_weights_lst,
our_outputs,
ref_outputs,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
if op_type == "sgmv":
sgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
else:
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
1.0,
op_type="expand",
)
slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)
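# A hedged, illustrative sketch (not one of the original tests): the minimal
# bgmv_shrink call pattern exercised above, written out with concrete shapes.
# It assumes a CUDA device (or the MLU hijack above) is available; the shapes
# mirror what generate_data() produces for op_type == "shrink".
def _example_bgmv_shrink_call():
    device = "cuda:0"
    tokens, hidden_size, rank, num_loras = 4, 128, 8, 2
    inputs = torch.rand((tokens, hidden_size),
                        dtype=torch.float16,
                        device=device)
    # One (rank, hidden_size) shrink weight per LoRA.
    lora_weights = torch.rand((num_loras, rank, hidden_size),
                              dtype=torch.float16,
                              device=device)
    # The shrink kernel writes float32 output.
    out = torch.zeros((tokens, rank), dtype=torch.float32, device=device)
    # Map every token to LoRA 0.
    indices = torch.zeros(tokens, dtype=torch.long, device=device)
    bgmv_shrink(inputs, lora_weights, out, indices, 0.5)
    return out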


@@ -0,0 +1,310 @@
"""
This script is mainly used to test whether Triton kernels can run normally
under different conditions, including various batch sizes, numbers of LoRAs,
and maximum ranks.
"""
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
'''
=============================
Modify by vllm_mlu
=============================
@brief: use mlu sgmv functions
'''
if current_platform.is_mlu():
from vllm_mlu.lora.ops.sgmv_expand import sgmv_expand_mlu as sgmv_expand
from vllm_mlu.lora.ops.sgmv_expand_slice import sgmv_expand_slice_mlu as sgmv_expand_slice
from vllm_mlu.lora.ops.sgmv_shrink import sgmv_shrink_mlu as sgmv_shrink
'''
==================
End of MLU Hijack
==================
'''
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
HIDDEN_SIZES = [4097]
BATCHES = [1, 4, 16, 32]
NUM_LORA = [1, 8, 32, 128]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 128
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
if op_type == "shrink":
sgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
sgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 1
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
if op_type == "shrink":
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,
lora_weights_lst,
our_outputs,
ref_outputs,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
if op_type == "sgmv":
sgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
else:
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
1.0,
op_type="expand",
)
slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)


@@ -0,0 +1,198 @@
# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from dataclasses import dataclass
from typing import List
import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
@dataclass
class ModelWithQuantization:
model_path: str
quantization: str
MODELS: List[ModelWithQuantization]
# AWQ quantization is currently not supported in ROCm.
if current_platform.is_rocm():
MODELS = [
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
quantization="GPTQ"),
]
else:
MODELS = [
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization="AWQ"),
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
quantization="GPTQ"),
]
def do_sample(llm: vllm.LLM,
lora_path: str,
lora_id: int,
max_tokens: int = 256) -> List[str]:
raw_prompts = [
"Give me an orange-ish brown color",
"Give me a neon pink color",
]
def format_prompt_tuples(prompt):
return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
prompts = [format_prompt_tuples(p) for p in raw_prompts]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=max_tokens,
stop=["<|im_end|>"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
tp_size):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(
model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_model_len=400,
tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.2,  # avoid OOM
quantization=model.quantization,
trust_remote_code=True)
if model.quantization is None:
expected_no_lora_output = [
"Here are some examples of orange-brown colors",
"I'm sorry, I don't have"
]
expected_lora_output = [
"#ff8050",
"#ff8080",
]
elif model.quantization == "AWQ":
expected_no_lora_output = [
"I'm sorry, I don't understand",
"I'm sorry, I don't understand",
]
expected_lora_output = [
"#f07700: A v",
"#f00000: A v",
]
elif model.quantization == "GPTQ":
expected_no_lora_output = [
"I'm sorry, I don't have",
"I'm sorry, I don't have",
]
expected_lora_output = [
"#f08800: This is",
"#f07788 \n#",
]
def expect_match(output, expected_output):
# HACK: GPTQ lora outputs are just incredibly unstable.
# Assert that the outputs changed.
if (model.quantization == "GPTQ"
and expected_output is expected_lora_output):
assert output != expected_no_lora_output
for i, o in enumerate(output):
assert o.startswith(
'#'), f"Expected example {i} to start with # but got {o}"
return
assert output == expected_output
max_tokens = 10
print("lora adapter created")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)
print("lora 1")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=1,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)
print("no lora")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)
print("lora 2")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=2,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)
print("removing lora")
del llm
cleanup_dist_env_and_memory()
@pytest.mark.parametrize("model", MODELS)
def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
model):
if num_gpus_available < 2:
pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
llm_tp1 = vllm.LLM(
model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1,
        gpu_memory_utilization=0.2,  # avoid OOM
quantization=model.quantization,
trust_remote_code=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
del llm_tp1
cleanup_dist_env_and_memory()
llm_tp2 = vllm.LLM(
model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2,
        gpu_memory_utilization=0.2,  # avoid OOM
quantization=model.quantization)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
del llm_tp2
cleanup_dist_env_and_memory()
assert output_tp1 == output_tp2


@@ -0,0 +1,55 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from ..conftest import get_tokenizer_pool_config
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
max_input_length=None,
)
lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(
request_id="request_id",
prompt="prompt",
lora_request=lora_request)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
lora_request) != tokenizer_group.get_lora_tokenizer(None)
assert tokenizer_group.get_lora_tokenizer(
lora_request) == await tokenizer_group.get_lora_tokenizer_async(
lora_request)
def test_get_lora_tokenizer(sql_lora_files, tmp_path):
lora_request = None
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
lora_request = LoRARequest("1", 1, sql_lora_files)
tokenizer = get_lora_tokenizer(lora_request)
assert tokenizer.get_added_vocab()
lora_request = LoRARequest("1", 1, str(tmp_path))
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer


@@ -0,0 +1,243 @@
from collections import OrderedDict
from unittest.mock import patch
import pytest
from huggingface_hub.utils import HfHubHTTPError
from torch import nn
from vllm.lora.utils import (get_adapter_absolute_path,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache
def test_parse_fine_tuned_lora_name_valid():
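    # Each fixture entry is (weight name, expected module name, is_lora_a,
    # is_bias).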
fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
(
"base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens",
True,
False,
),
(
"base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens",
False,
False,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj",
True,
False,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj",
False,
False,
),
}
for name, module_name, is_lora_a, is_bias in fixture:
assert (module_name, is_lora_a,
is_bias) == parse_fine_tuned_lora_name(name)
def test_parse_fine_tuned_lora_name_invalid():
fixture = {
"base_model.weight",
"base_model.model.weight",
}
for name in fixture:
with pytest.raises(ValueError, match="unsupported LoRA weight"):
parse_fine_tuned_lora_name(name)
def test_replace_submodule():
model = nn.Sequential(
OrderedDict([
("dense1", nn.Linear(764, 100)),
("act1", nn.ReLU()),
("dense2", nn.Linear(100, 50)),
(
"seq1",
nn.Sequential(
OrderedDict([
("dense1", nn.Linear(100, 10)),
("dense2", nn.Linear(10, 50)),
])),
),
("act2", nn.ReLU()),
("output", nn.Linear(50, 10)),
("outact", nn.Sigmoid()),
]))
sigmoid = nn.Sigmoid()
replace_submodule(model, "act1", sigmoid)
assert dict(model.named_modules())["act1"] == sigmoid
dense2 = nn.Linear(1, 5)
replace_submodule(model, "seq1.dense2", dense2)
assert dict(model.named_modules())["seq1.dense2"] == dense2
class TestLRUCache(LRUCache):
def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1
def test_lru_cache():
cache = TestLRUCache(3)
cache.put(1, 1)
assert len(cache) == 1
cache.put(1, 1)
assert len(cache) == 1
cache.put(2, 2)
assert len(cache) == 2
cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache.get(2) == 2
cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4
cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6
cache._remove_counter = 0
cache[1] = 1
assert len(cache) == 1
cache[1] = 1
assert len(cache) == 1
cache[2] = 2
assert len(cache) == 2
cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2
cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
# Unit tests for get_adapter_absolute_path
@patch('os.path.isabs')
def test_get_adapter_absolute_path_absolute(mock_isabs):
path = '/absolute/path/to/lora'
mock_isabs.return_value = True
assert get_adapter_absolute_path(path) == path
@patch('os.path.expanduser')
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
# Path with ~ that needs to be expanded
path = '~/relative/path/to/lora'
absolute_path = '/home/user/relative/path/to/lora'
mock_expanduser.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path
@patch('os.path.exists')
@patch('os.path.abspath')
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
# Relative path that exists locally
path = 'relative/path/to/lora'
absolute_path = '/absolute/path/to/lora'
mock_exist.return_value = True
mock_abspath.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier
path = 'org/repo'
absolute_path = '/mock/snapshot/path'
mock_exist.return_value = False
mock_snapshot_download.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier with download error
path = 'org/repo'
mock_exist.return_value = False
mock_snapshot_download.side_effect = HfHubHTTPError(
"failed to query model info")
assert get_adapter_absolute_path(path) == path


@@ -0,0 +1,74 @@
import os
import random
import tempfile
from unittest.mock import patch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
vllm_config = VllmConfig(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
task="auto",
tokenizer="meta-llama/Llama-2-7b-hf",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
),
load_config=LoadConfig(
download_dir=None,
load_format="dummy",
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig("generate", 32, 32, 32),
device_config=DeviceConfig("cuda"),
cache_config=CacheConfig(block_size=16,
gpu_memory_utilization=1.,
swap_space=0,
cache_dtype="auto"),
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
max_loras=32),
)
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_device()
worker.load_model()
worker.model_runner.set_active_loras([], LoRAMapping([], []))
assert worker.list_loras() == set()
n_loras = 32
lora_requests = [
LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
]
worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
assert worker.list_loras() == {
lora_request.lora_int_id
for lora_request in lora_requests
}
for i in range(32):
random.seed(i)
iter_lora_requests = random.choices(lora_requests,
k=random.randint(1, n_loras))
random.shuffle(iter_lora_requests)
iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
worker.model_runner.set_active_loras(iter_lora_requests,
LoRAMapping([], []))
assert worker.list_loras().issuperset(
{lora_request.lora_int_id
for lora_request in iter_lora_requests})


@@ -0,0 +1,237 @@
from typing import Dict, List, Optional
import torch
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
class DummyLoRAManager:
def __init__(self, device: torch.device = "cuda:0"):
super().__init__()
self._loras: Dict[str, LoRALayerWeights] = {}
self._device = device
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora
def get_module_lora(self, module_name: str) -> LoRALayerWeights:
return self._loras[module_name]
def init_random_lora(self,
module_name: str,
weight: torch.Tensor,
rank: int = 8,
generate_embeddings_tensor: int = 0):
lora = LoRALayerWeights(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
dtype=weight.dtype,
device=self._device),
lora_b=torch.rand([rank, weight.shape[0]],
dtype=weight.dtype,
device=self._device),
)
if generate_embeddings_tensor:
lora.embeddings_tensor = torch.rand(5,
generate_embeddings_tensor,
dtype=weight.dtype,
device=self._device)
self.set_module_lora(module_name, lora)
return lora
def init_lora(self,
module_name: str,
input_dim: int,
output_dim: int,
rank=8,
noop=False,
embeddings_tensor=None):
lora = LoRALayerWeights(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([input_dim, rank], device="cuda"),
lora_b=torch.rand([rank, output_dim], device="cuda"),
embeddings_tensor=embeddings_tensor,
)
self.set_module_lora(module_name, lora)
return lora
def reset_lora(self):
self._loras = {}
def init_packed_lora(
self,
module_name: str,
input_dim: int,
output_dims: List[int],
noop_lora_index: Optional[List[int]] = None,
rank: int = 8,
):
base_loras: List[LoRALayerWeights] = []
noop_lora_index_set = set(noop_lora_index or [])
for i, out_dim in enumerate(output_dims):
base_lora = self.init_lora(
module_name + "_000_" + str(i),
input_dim,
out_dim,
rank=rank,
noop=i in noop_lora_index_set,
)
base_loras.append(base_lora)
packed_lora = PackedLoRALayerWeights.pack(base_loras)
self.set_module_lora(module_name, packed_lora)
return packed_lora
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def ref_torch_groupgemm(
out_tensor,
inputs,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling,
op_type,
) -> torch.Tensor:
out_list = []
current_offset = 0
for lora_index, b_length in zip(range(batches), seq_len_tensor):
input_weight = inputs[current_offset:b_length + current_offset, :]
current_offset += b_length
lora_weight = lora_weights[lora_indices_tensor[lora_index]]
result = torch.nn.functional.linear(input_weight, lora_weight)
result *= scaling
out_list.append(result)
cat_result = torch.cat(out_list, dim=0)
if op_type == "expand":
out_tensor += cat_result
else:
out_tensor.copy_(cat_result)
return
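# A small, hedged self-check (illustrative only; not used by the tests): with
# one batch and one LoRA, the "shrink" path of ref_torch_groupgemm reduces to
# a single plain linear, out = scaling * (x @ W.T).
def _sanity_check_ref_torch_groupgemm():
    inputs = torch.rand((4, 16))
    lora_weights = torch.rand((1, 8, 16))  # [num_loras, rank, hidden_size]
    out = torch.zeros((4, 8))
    ref_torch_groupgemm(out,
                        inputs,
                        lora_weights,
                        lora_indices_tensor=torch.tensor([0]),
                        seq_len_tensor=torch.tensor([4]),
                        batches=1,
                        scaling=0.5,
                        op_type="shrink")
    torch.testing.assert_close(out, 0.5 * (inputs @ lora_weights[0].T))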
def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype,
op_type, device):
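    # randint over [seq_length, seq_length + 1) is constant: every batch gets
    # exactly seq_length tokens.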
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
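    # b_seq_start_loc[i] is the flattened-token offset at which batch i starts
    # (an exclusive prefix sum of seq_len_tensor).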
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
dim=0,
).to(device)
total_tokens = seq_len_tensor.sum()
if op_type == "shrink":
inputs_tensor = torch.rand((total_tokens, hidden_size),
dtype=dtype).to(device)
lora_weights = torch.rand(
(lora_nums, max_rank, hidden_size), # col-major
dtype=dtype,
).to(device)
        # The shrink op needs atomic_add, so the output is initialized to 0.
ref_out_tensor = torch.zeros((total_tokens, max_rank),
dtype=dtype,
device=inputs_tensor.device)
        # NOTE: the shrink kernel uses torch.float32 as its output dtype.
our_out_tensor = torch.zeros((total_tokens, max_rank),
dtype=torch.float32).to(device)
else:
inputs_tensor = torch.rand(
(total_tokens, max_rank),
dtype=dtype,
).to(device)
lora_weights = torch.rand(
(lora_nums, hidden_size, max_rank), # col-major
dtype=dtype,
).to(device)
        # The expand op needs to compute y += a @ lora_b, so the output is
        # initialized randomly.
ref_out_tensor = torch.rand(
(total_tokens, hidden_size),
dtype=dtype,
).to(device)
# Ensure the same input.
our_out_tensor = ref_out_tensor.clone()
lora_indices_tensor = torch.randint(0,
lora_nums - 1 if lora_nums > 1 else 1,
(batches, )).to(device)
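    # NOTE: torch.randint's upper bound is exclusive, so when lora_nums > 1
    # the last LoRA slot is never selected here.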
indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
current_offset = 0
for b_id in range(batches):
lora_index = lora_indices_tensor[b_id]
indices[current_offset:current_offset +
seq_len_tensor[b_id]].copy_(lora_index)
current_offset += seq_len_tensor[b_id].item()
return (
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
)
def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank,
seq_length, dtype, nslices, device):
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
dim=0,
).to(device)
total_tokens = seq_len_tensor.sum()
inputs_tensor = torch.rand(
(total_tokens, max_rank),
dtype=dtype,
).to(device)
lora_weights_lst = []
for _ in range(nslices):
lora_weights_lst.append(
torch.rand(
(lora_nums, hidden_size, max_rank), # col-major
dtype=dtype,
).to(device))
    # The expand op needs to compute y += a @ lora_b, so the output is
    # initialized randomly.
ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices),
dtype=dtype).to(device)
# Ensure the same input.
our_out_tensor = ref_out_tensor.clone()
lora_indices_tensor = torch.randint(0,
lora_nums - 1 if lora_nums > 1 else 1,
(batches, ))
indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
current_offset = 0
for b_id in range(batches):
lora_index = lora_indices_tensor[b_id]
indices[current_offset:current_offset +
seq_len_tensor[b_id]] = lora_index.item()
current_offset += seq_len_tensor[b_id].item()
lora_indices_tensor = lora_indices_tensor.to(device)
return (
inputs_tensor,
lora_weights_lst,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
)