[V1][LoRA][Test] V1 Engine LoRA support & e2e test (#893)

### What this PR does / why we need it?

Add LoRA support to the V1 Engine.
Add LoRA e2e tests on a single card and on multiple cards.

### Does this PR introduce _any_ user-facing change?
Yes. LoRA is now supported with the V1 Engine.

### How was this patch tested?

CI passed with the newly added tests.
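
For reference, a minimal single-card sketch of how the new support can be exercised end to end (the model name and adapter path below are placeholders, not necessarily what the added test uses):

```python
# Minimal V1-engine LoRA sketch on a single card; names/paths are illustrative.
import os

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

os.environ["VLLM_USE_V1"] = "1"  # run against the V1 engine

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # placeholder base model
    enable_lora=True,
    max_loras=1,
    max_lora_rank=8,
)

outputs = llm.generate(
    ["Write a SQL query that lists all users older than 30."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # LoRARequest(name, unique int id, local adapter path)
    lora_request=LoRARequest("sql-lora", 1, "/path/to/sql-lora-adapter"),
)
print(outputs[0].outputs[0].text)
```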

---------

Signed-off-by: jesse <szxfml@gmail.com>
Signed-off-by: paulyu <paulyu0307@gmail.com>
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: jesse <szxfml@gmail.com>
Co-authored-by: paulyu <paulyu0307@gmail.com>
Commit 0f53b138f6 (parent 7aa4f85f10), authored by yupeng on 2025-05-22 19:20:51 +08:00, committed by GitHub.
6 changed files with 167 additions and 38 deletions


@@ -50,6 +50,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -102,7 +103,7 @@ def graph_capture(device: torch.device):
         yield graph_capture_context
 
 
-class NPUModelRunner:
+class NPUModelRunner(LoRAModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
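
The change above pulls the LoRA plumbing in through a mixin rather than reimplementing it in the NPU runner; a toy illustration of the pattern (names here are invented, only the shape matters):

```python
# Toy mixin illustration: the runner decides *when* to call the LoRA hooks;
# the mixin owns *how* adapters are loaded and activated.
class LoRAMixinDemo:
    def load_adapters(self) -> None:
        print("wrap model layers with LoRA support")

    def activate_for_batch(self, adapter_ids: set[int]) -> None:
        print(f"activate adapters {adapter_ids} for this step")


class RunnerDemo(LoRAMixinDemo):
    def load_model(self) -> None:
        print("load base weights")
        self.load_adapters()                  # mirrors load_lora_model()

    def execute(self, adapter_ids: set[int]) -> None:
        self.activate_for_batch(adapter_ids)  # mirrors set_active_loras()
        print("run forward pass")


RunnerDemo().load_model()
RunnerDemo().execute({1})
```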
@@ -543,6 +544,10 @@ class NPUModelRunner:
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
 
+        # Hot-Swap lora model
+        if self.lora_config:
+            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
         # Prepare positions
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
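
The "Hot-Swap" step above activates whatever adapters the scheduled batch needs before each forward pass. Conceptually (a standalone illustration with made-up names, not the LoRAModelRunnerMixin internals), it boils down to collecting the adapter ids of the scheduled requests and expanding them into a per-token mapping for the LoRA kernels:

```python
# Conceptual sketch of per-step LoRA activation; types/names are hypothetical.
from dataclasses import dataclass


@dataclass
class ScheduledReq:
    req_id: str
    lora_int_id: int          # 0 means "no adapter"
    num_scheduled_tokens: int


def build_lora_mapping(reqs: list[ScheduledReq]) -> tuple[set[int], list[int]]:
    """Return (adapter ids to activate, per-token adapter index)."""
    active: set[int] = set()
    token_mapping: list[int] = []
    for req in reqs:
        if req.lora_int_id:
            active.add(req.lora_int_id)
        token_mapping.extend([req.lora_int_id] * req.num_scheduled_tokens)
    return active, token_mapping


# Two requests: one on adapter 1 (3 tokens), one without LoRA (2 tokens).
print(build_lora_mapping([ScheduledReq("a", 1, 3), ScheduledReq("b", 0, 2)]))
# -> ({1}, [1, 1, 1, 0, 0])
```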
@@ -867,39 +872,55 @@ class NPUModelRunner:
     @torch.inference_mode()
     def _dummy_run(self, num_tokens: int) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
 
-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]
 
-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=num_tokens,
-                        dtype=self.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=num_tokens,
+                            dtype=self.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
 
-        with set_forward_context(None, self.vllm_config):
-            hidden_states = model(input_ids=input_ids,
-                                  positions=positions,
-                                  intermediate_tensors=intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
-        return hidden_states
+            with set_forward_context(None, self.vllm_config):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
+            return hidden_states
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
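
The token-splitting arithmetic in the `_dummy_run` hunk above can be sanity-checked in isolation; a standalone sketch that mirrors it (no vLLM imports needed):

```python
# Mirror of the dummy-run split: spread num_tokens over at most max_num_seqs
# requests so the LoRA dummy batch has a realistic request count.
import numpy as np


def split_dummy_tokens(num_tokens: int, max_num_seqs: int) -> np.ndarray:
    num_reqs = max_num_seqs if num_tokens >= max_num_seqs else num_tokens
    min_tokens_per_req = num_tokens // num_reqs
    tokens = [min_tokens_per_req] * num_reqs
    tokens[-1] += num_tokens % num_reqs  # remainder goes to the last request
    assert sum(tokens) == num_tokens
    return np.array(tokens, dtype=np.int32)


print(split_dummy_tokens(10, 4))  # -> [2 2 2 4]
```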
@@ -948,7 +969,11 @@ class NPUModelRunner:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
             if self.lora_config:
-                raise ValueError("LoRA model is not supported on NPU now.")
+                self.model = self.load_lora_model(self.model,
+                                                  self.model_config,
+                                                  self.scheduler_config,
+                                                  self.lora_config,
+                                                  self.device)
 
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))


@@ -31,6 +31,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               set_custom_all_reduce)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.logger import logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -216,6 +217,18 @@ class NPUWorker(WorkerBase):
         else:
             self.profiler.stop()
 
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def list_loras(self) -> set[int]:
+        return self.model_runner.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
     def execute_dummy_batch(self) -> None:
         self.model_runner._dummy_run(1)
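
For the multi-card part of the test description, the same flow runs with tensor parallelism across two cards; a sketch under the same placeholder names (not copied from the added test):

```python
# Multi-card sketch: identical LoRA request, weights sharded over 2 NPUs.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # placeholder base model
    enable_lora=True,
    max_loras=1,
    max_lora_rank=8,
    tensor_parallel_size=2,            # two cards
)
out = llm.generate(
    ["List the three largest tables by row count."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("sql-lora", 1, "/path/to/sql-lora-adapter"),
)
print(out[0].outputs[0].text)
```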