[Platform][Worker][ModelRunner] Add LoRA & Multi-LoRA support (#521)

### What this PR does / why we need it?
According to this RFC [[RFC]: Join the MultiLora and MultiLora Dynammic
Serving feature develop
#396](https://github.com/vllm-project/vllm-ascend/issues/396) and this
[vLLM Ascend Roadmap Q2 2025
#448](https://github.com/vllm-project/vllm-ascend/issues/448), we pull
request relavant code to support (1) Multi-LoRA and (2) Multi-LoRA
Dynamic Serving.

LoRA reference is here: [LoRA
reference](https://docs.vllm.ai/en/latest/features/lora.html)

### Does this PR introduce _any_ user-facing change?

Following openai HTTP apis will be supported:
/v1/load_lora_adapter
/v1/unload_lora_adapter

### How was this patch tested?
git clone https://github.com/vllm-project/vllm.git
cd vllm/examples/offline_inference/ && python3 multilora_inference.py

---------

Signed-off-by: paulyu <paulyu0307@gmail.com>
Co-authored-by: paulyu <paulyu0307@gmail.com>
This commit is contained in:
paulyu12
2025-04-17 16:48:46 +08:00
committed by GitHub
parent 9935d45728
commit 697908f5cd
4 changed files with 484 additions and 14 deletions

View File

@@ -38,11 +38,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
@@ -79,6 +81,8 @@ class ModelInputForNPU(ModelRunnerInputBase):
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
@@ -93,6 +97,8 @@ class ModelInputForNPU(ModelRunnerInputBase):
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
@@ -139,6 +145,8 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
@@ -181,6 +189,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
self.query_lens[0] = 0 # type: ignore
self.context_lens[0] = 0 # type: ignore
self.curr_sliding_window_blocks[0] = 0 # type: ignore
self.lora_index_mapping.clear() # type: ignore
self.lora_prompt_mapping.clear() # type: ignore
self.lora_requests.clear() # type: ignore
def __init__(
self,
@@ -211,6 +222,11 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
# The current sliding window block.
curr_sliding_window_blocks: Optional[List[int]] = None,
# LoRA inputs.
lora_index_mapping: Optional[List[List[int]]] = None,
lora_prompt_mapping: Optional[List[List[int]]] = None,
lora_requests: Optional[Set[LoRARequest]] = None,
# Multi-modal inputs.
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_placeholder_maps: Optional[Dict[
@@ -291,6 +307,19 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
for seq_id in range(len(self.seq_ids)):
self.curr_sliding_window_blocks[seq_id] = 0
if lora_index_mapping:
self.lora_index_mapping = lora_index_mapping
else:
self.lora_index_mapping.clear()
if lora_prompt_mapping:
self.lora_prompt_mapping = lora_prompt_mapping
else:
self.lora_prompt_mapping.clear()
if lora_requests:
self.lora_requests = lora_requests
else:
self.lora_requests.clear()
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
@@ -303,6 +332,10 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
self.curr_sliding_window_blocks = \
curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.multi_modal_kwargs = multi_modal_kwargs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit
@@ -325,6 +358,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs
self.lora_index_mapping = []
self.lora_prompt_mapping = []
def __init__(self,
runner,
finished_requests_ids: Optional[List[str]] = None):
@@ -335,6 +371,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
self._compute_lens,
self._compute_for_prefix_cache_hit,
self._compute_for_sliding_window,
self._compute_lora_input,
]
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
@@ -348,6 +385,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
self.scheduler_config = self.runner.scheduler_config
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.enable_lora = self.runner.lora_config is not None
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.finished_requests_ids = finished_requests_ids
self.decode_only = True
@@ -512,6 +550,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(r for data in self.inter_data_list
for r in data.lora_requests)
lora_index_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_index_mapping)
for inter_data in self.inter_data_list
])
lora_prompt_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_prompt_mapping)
for inter_data in self.inter_data_list
])
lora_mapping = LoRAMapping(
**dict(index_mapping=lora_index_mapping,
prompt_mapping=lora_prompt_mapping,
is_prefill=not self.decode_only))
# Multi-modal data.
multi_modal_kwargs_list = [
data.multi_modal_kwargs for data in self.inter_data_list
@@ -525,6 +582,8 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids)
@@ -663,6 +722,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
seq_idx] = curr_sliding_window_block
inter_data.seq_lens[seq_idx] = sliding_seq_len
def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""If LoRA is enabled, compute LoRA index and prompt mapping."""
if not self.enable_lora:
return
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
sampling_params = seq_group_metadata.sampling_params
if sampling_params and sampling_params.prompt_logprobs is not None:
inter_data.lora_prompt_mapping.append([lora_id] * query_len)
elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
inter_data.lora_prompt_mapping.append([lora_id])
else:
inter_data.lora_prompt_mapping.append([])
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
@@ -789,6 +867,8 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
# Lazy initialization
self.model: nn.Module # Set after load_model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
@@ -818,6 +898,32 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
if self.lora_config:
assert supports_lora(
self.model
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# It's necessary to distinguish between the max_position_embeddings
# of VLMs and LLMs.
if hasattr(self.model.config, "max_position_embeddings"):
max_pos_embeddings = self.model.config.max_position_embeddings
else:
max_pos_embeddings = (
self.model.config.text_config.max_position_embeddings)
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
def save_sharded_state(
self,
path: str,
@@ -967,23 +1073,35 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
return
def remove_all_loras(self):
raise RuntimeError("LoRA is not supported on NPU now.")
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
raise RuntimeError("LoRA is not supported on NPU now.")
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
raise RuntimeError("LoRA is not supported on NPU now.")
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
raise RuntimeError("LoRA is not supported on NPU now.")
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
raise RuntimeError("LoRA is not supported on NPU now.")
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
raise RuntimeError("LoRA is not supported on NPU now.")
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
def remove_all_prompt_adapters(self):
raise RuntimeError("PromptAdapter is not supported on NPU now.")
@@ -1086,6 +1204,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
self.attn_state.begin_forward(model_input)
assert model_input.attn_metadata is not None