[V1][LoRA][Test] V1 Engine LoRA support & e2e test (#893)

### What this PR does / why we need it?

Add LoRA support to the V1 Engine.
Add LoRA e2e tests on a single card and on multiple cards.

### Does this PR introduce _any_ user-facing change?
Yes. LoRA adapters are now supported when running with the V1 engine.
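
For illustration, a minimal sketch of how a user could exercise LoRA on the V1 engine. This is not part of the PR; it reuses the base model and adapter from the new e2e tests, assumes the V1 engine is opted into via `VLLM_USE_V1=1` (as the CI workflow does), and the adapter name `"sql-lora"` is a placeholder:

```python
# Illustrative sketch only: run a LoRA adapter on the V1 engine.
import os

os.environ["VLLM_USE_V1"] = "1"  # assumption: V1 engine is selected via this env var

from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Same base model and LoRA adapter as the new e2e tests.
lora_path = snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
llm = LLM(model="ArthurZ/ilama-3.2-1B", enable_lora=True, max_loras=4)

outputs = llm.generate(
    ["How many singers do we have?"],
    SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("sql-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)
```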

### How was this patch tested?

CI passed with the newly added tests (`tests/singlecard/test_ilama_lora.py` and `tests/multicard/test_ilama_lora_tp2.py`).

---------

Signed-off-by: jesse <szxfml@gmail.com>
Signed-off-by: paulyu <paulyu0307@gmail.com>
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: jesse <szxfml@gmail.com>
Co-authored-by: paulyu <paulyu0307@gmail.com>
yupeng
2025-05-22 19:20:51 +08:00
committed by GitHub
parent 7aa4f85f10
commit 0f53b138f6
6 changed files with 167 additions and 38 deletions

View File

@@ -51,11 +51,11 @@ jobs:
vllm_verison: [main, v0.8.5.post1]
concurrency:
group: >
${{
matrix.os == 'linux-arm64-npu-4'
&& github.event.pull_request.number
&& format('pr-{0}-limit-npu-4', github.event.pull_request.number)
|| format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
}}
cancel-in-progress: false
name: vLLM Ascend test
@@ -112,10 +112,12 @@ jobs:
run: |
  if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
    pytest -sv tests/singlecard/test_offline_inference.py
    pytest -sv tests/singlecard/test_ilama_lora.py
    pytest -sv tests/ops
    pytest -sv tests/compile
  else
    pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
    pytest -sv tests/multicard/test_ilama_lora_tp2.py
    pytest -sv tests/ops
    pytest -sv tests/compile
  fi
@@ -125,9 +127,11 @@ jobs:
VLLM_USE_V1: 0
run: |
  if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
    pytest -sv tests/singlecard/test_ilama_lora.py
    pytest -sv tests/singlecard/test_offline_inference.py
    pytest -sv tests/ops
  else
    pytest -sv tests/multicard/test_ilama_lora_tp2.py
    pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
    pytest -sv -k "DeepSeek" tests/multicard/test_offline_inference_distributed.py
    pytest -sv tests/ops

View File

@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np
import pytest
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.config import TaskOption
@@ -348,4 +349,9 @@ def vllm_runner():
@pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
def prompt_template(request):
    return PROMPT_TEMPLATES[request.param]


@pytest.fixture(scope="session")
def ilama_lora_files():
    return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")

View File

@@ -0,0 +1,21 @@
import pytest

from tests.conftest import VllmRunner
from tests.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT, MODEL_PATH,
                                              do_sample)


@pytest.mark.parametrize("distributed_executor_backend", ["mp"])
def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files):
    with VllmRunner(model_name=MODEL_PATH,
                    enable_lora=True,
                    max_loras=4,
                    max_model_len=1024,
                    max_num_seqs=16,
                    tensor_parallel_size=2,
                    distributed_executor_backend=distributed_executor_backend
                    ) as vllm_model:
        output = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)

    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output[i] == EXPECTED_LORA_OUTPUT[i]

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
import vllm
from vllm.lora.request import LoRARequest

from tests.conftest import VllmRunner

MODEL_PATH = "ArthurZ/ilama-3.2-1B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501

EXPECTED_LORA_OUTPUT = [
    "SELECT count(*) FROM singer",
    "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",  # noqa: E501
    "SELECT DISTINCT Country FROM singer WHERE Age > 20",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "What are all distinct countries where singers above age 20 are from?"  # noqa: E501
        ),
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_ilama_lora(ilama_lora_files):
    with VllmRunner(model_name=MODEL_PATH,
                    enable_lora=True,
                    max_loras=4,
                    max_model_len=1024,
                    max_num_seqs=16) as vllm_model:
        output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1)
        for i in range(len(EXPECTED_LORA_OUTPUT)):
            assert output1[i] == EXPECTED_LORA_OUTPUT[i]
        output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)
        for i in range(len(EXPECTED_LORA_OUTPUT)):
            assert output2[i] == EXPECTED_LORA_OUTPUT[i]

View File

@@ -50,6 +50,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.sampler import Sampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -102,7 +103,7 @@ def graph_capture(device: torch.device):
        yield graph_capture_context


class NPUModelRunner:
class NPUModelRunner(LoRAModelRunnerMixin):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
@@ -543,6 +544,10 @@ class NPUModelRunner:
            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                           num_tokens)

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        # Prepare positions
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)
@@ -867,39 +872,55 @@ class NPUModelRunner:
    @torch.inference_mode()
    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)
        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            model = self.model
            if self.is_multimodal_model:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_tokens]
            else:
                positions = self.positions[:num_tokens]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=num_tokens,
                            dtype=self.dtype,
                            device=self.device))
                intermediate_tensors = IntermediateTensors({
                    k: v[:num_tokens]
                    for k, v in self.intermediate_tensors.items()
                })

            with set_forward_context(None, self.vllm_config):
                hidden_states = model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds)
            return hidden_states

    def profile_run(self) -> None:
        # Profile with multimodal encoder & encoder cache.
@@ -948,7 +969,11 @@ class NPUModelRunner:
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                raise ValueError("LoRA model is not supported on NPU now.")
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)

        logger.info("Loading model weights took %.4f GB",
                    m.consumed_memory / float(2**30))

View File

@@ -31,6 +31,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
@@ -216,6 +217,18 @@ class NPUWorker(WorkerBase):
        else:
            self.profiler.stop()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def execute_dummy_batch(self) -> None:
        self.model_runner._dummy_run(1)