[Platform][Worker][ModelRunner] Add LoRA & Multi-LoRA support (#521)
### What this PR does / why we need it?

According to the RFC [[RFC]: Join the MultiLora and MultiLora Dynamic Serving feature develop #396](https://github.com/vllm-project/vllm-ascend/issues/396) and the [vLLM Ascend Roadmap Q2 2025 #448](https://github.com/vllm-project/vllm-ascend/issues/448), this PR adds the code needed to support (1) Multi-LoRA and (2) Multi-LoRA dynamic serving. LoRA reference: [LoRA reference](https://docs.vllm.ai/en/latest/features/lora.html)

### Does this PR introduce _any_ user-facing change?

The following OpenAI-compatible HTTP APIs are now supported:

- /v1/load_lora_adapter
- /v1/unload_lora_adapter

### How was this patch tested?

    git clone https://github.com/vllm-project/vllm.git
    cd vllm/examples/offline_inference/ && python3 multilora_inference.py

---------

Signed-off-by: paulyu <paulyu0307@gmail.com>
Co-authored-by: paulyu <paulyu0307@gmail.com>
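For reference, a minimal dynamic-serving sketch against a running OpenAI-compatible server. The payload fields follow the upstream vLLM docs; the server address, adapter name, and adapter path are placeholders, and the server must be started with LoRA enabled and runtime LoRA updating allowed.

```python
# Minimal sketch (not part of this PR): the server is assumed to run with
# --enable-lora and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True; the adapter
# name/path and server address below are placeholders.
import json
from urllib import request

BASE = "http://localhost:8000"  # assumed server address

def post(route: str, payload: dict) -> str:
    """POST a JSON payload and return the raw response body."""
    req = request.Request(BASE + route,
                          data=json.dumps(payload).encode(),
                          headers={"Content-Type": "application/json"})
    with request.urlopen(req) as resp:
        return resp.read().decode()

# Register a LoRA adapter at runtime, then remove it again.
print(post("/v1/load_lora_adapter",
           {"lora_name": "sql_adapter", "lora_path": "/path/to/sql_lora"}))
print(post("/v1/unload_lora_adapter", {"lora_name": "sql_adapter"}))
```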
@@ -38,11 +38,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap,
@@ -79,6 +81,8 @@ class ModelInputForNPU(ModelRunnerInputBase):
    token_types: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
@@ -93,6 +97,8 @@ class ModelInputForNPU(ModelRunnerInputBase):
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
@@ -139,6 +145,8 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
@@ -181,6 +189,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
            self.query_lens[0] = 0  # type: ignore
            self.context_lens[0] = 0  # type: ignore
            self.curr_sliding_window_blocks[0] = 0  # type: ignore
            self.lora_index_mapping.clear()  # type: ignore
            self.lora_prompt_mapping.clear()  # type: ignore
            self.lora_requests.clear()  # type: ignore

        def __init__(
            self,
@@ -211,6 +222,11 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,

            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Multi-modal inputs.
            multi_modal_kwargs: Optional[MultiModalKwargs] = None,
            multi_modal_placeholder_maps: Optional[Dict[
@@ -291,6 +307,19 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
                        for seq_id in range(len(self.seq_ids)):
                            self.curr_sliding_window_blocks[seq_id] = 0

                    if lora_index_mapping:
                        self.lora_index_mapping = lora_index_mapping
                    else:
                        self.lora_index_mapping.clear()
                    if lora_prompt_mapping:
                        self.lora_prompt_mapping = lora_prompt_mapping
                    else:
                        self.lora_prompt_mapping.clear()
                    if lora_requests:
                        self.lora_requests = lora_requests
                    else:
                        self.lora_requests.clear()

            else:
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
@@ -303,6 +332,10 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
                self.curr_sliding_window_blocks = \
                    curr_sliding_window_blocks or []

                self.lora_index_mapping = lora_index_mapping or []
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()

            self.multi_modal_kwargs = multi_modal_kwargs
            self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
            self.prefix_cache_hit = prefix_cache_hit
@@ -325,6 +358,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
            self.context_lens = [0] * self.n_seqs
            self.curr_sliding_window_blocks = [0] * self.n_seqs

            self.lora_index_mapping = []
            self.lora_prompt_mapping = []

    def __init__(self,
                 runner,
                 finished_requests_ids: Optional[List[str]] = None):
@@ -335,6 +371,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
            self._compute_lens,
            self._compute_for_prefix_cache_hit,
            self._compute_for_sliding_window,
            self._compute_lora_input,
        ]
        # Compute functions for each sequence group.
        # WARNING: The order of the functions matters!
@@ -348,6 +385,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
        self.scheduler_config = self.runner.scheduler_config
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.enable_lora = self.runner.lora_config is not None
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.finished_requests_ids = finished_requests_ids
        self.decode_only = True
@@ -512,6 +550,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(r for data in self.inter_data_list
                                for r in data.lora_requests)
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_mapping = LoRAMapping(
                **dict(index_mapping=lora_index_mapping,
                       prompt_mapping=lora_prompt_mapping,
                       is_prefill=not self.decode_only))

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
@@ -525,6 +582,8 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids)
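For reference, a minimal sketch of the two-level flattening performed in the hunk above, assuming `flatten_2d_lists` has the usual concatenate-sublists semantics of the upstream vLLM helper; the per-group values are illustrative:

```python
from typing import List, TypeVar

T = TypeVar("T")

def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Concatenate a list of lists into a single flat list."""
    return [item for sublist in lists for item in sublist]

# Two sequence groups: group 0 runs adapter 1 over 3 query tokens,
# group 1 runs adapter 2 over 2 query tokens.
per_group_index_mapping = [[[1, 1, 1]], [[2, 2]]]

# Flatten within each group, then across groups, to get the batch mapping.
batch_index_mapping = flatten_2d_lists(
    [flatten_2d_lists(group) for group in per_group_index_mapping])
assert batch_index_mapping == [1, 1, 1, 2, 2]
```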
@@ -663,6 +722,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
                seq_idx] = curr_sliding_window_block
        inter_data.seq_lens[seq_idx] = sliding_seq_len

    def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
                            seq_idx: int,
                            seq_group_metadata: SequenceGroupMetadata):
        """If LoRA is enabled, compute LoRA index and prompt mapping."""
        if not self.enable_lora:
            return
        lora_id = seq_group_metadata.lora_int_id
        if lora_id > 0:
            inter_data.lora_requests.add(seq_group_metadata.lora_request)
        query_len = inter_data.query_lens[seq_idx]
        inter_data.lora_index_mapping.append([lora_id] * query_len)
        sampling_params = seq_group_metadata.sampling_params
        if sampling_params and sampling_params.prompt_logprobs is not None:
            inter_data.lora_prompt_mapping.append([lora_id] * query_len)
        elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
            inter_data.lora_prompt_mapping.append([lora_id])
        else:
            inter_data.lora_prompt_mapping.append([])

    def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                   seq_group_metadata: SequenceGroupMetadata):
        """If multi-modal data is given, add it to the input."""
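To illustrate the shapes `_compute_lora_input` produces (values below are illustrative, not the module's code): the token-level index mapping repeats the adapter id once per query token, while the sampler-level prompt mapping holds a single entry per sampled token unless `prompt_logprobs` is requested:

```python
# One prefill sequence using adapter id 3 with a 4-token query.
lora_id, query_len = 3, 4

# Token-level mapping: one adapter id per query token.
index_mapping = [lora_id] * query_len                 # [3, 3, 3, 3]

# Sampler-level mapping: one entry for the single sampled token by default...
prompt_mapping = [lora_id]                            # [3]
# ...but one entry per prompt token when prompt_logprobs is requested.
prompt_mapping_with_logprobs = [lora_id] * query_len  # [3, 3, 3, 3]
```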
@@ -789,6 +867,8 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))
@@ -818,6 +898,32 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert supports_lora(
                self.model
            ), f"{self.model.__class__.__name__} does not support LoRA yet."
            if supports_multimodal(self.model):
                logger.warning("Regarding multimodal models, vLLM currently "
                               "only supports adding LoRA to language model.")
            # It's necessary to distinguish between the
            # max_position_embeddings of VLMs and LLMs.
            if hasattr(self.model.config, "max_position_embeddings"):
                max_pos_embeddings = self.model.config.max_position_embeddings
            else:
                max_pos_embeddings = (
                    self.model.config.text_config.max_position_embeddings)
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=max_pos_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

    def save_sharded_state(
        self,
        path: str,
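For context, the `lora_config` branch above is only reached when the engine is constructed with LoRA enabled. A minimal sketch (the model name and limits are example values):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # example model
    enable_lora=True,   # populates lora_config, enabling the branch above
    max_loras=4,        # max adapters active in a single batch
    max_lora_rank=16,   # largest adapter rank that will be loaded
)
```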
@@ -967,23 +1073,35 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
            return

    def remove_all_loras(self):
-        raise RuntimeError("LoRA is not supported on NPU now.")
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.remove_all_adapters()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
-        raise RuntimeError("LoRA is not supported on NPU now.")
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
-        raise RuntimeError("LoRA is not supported on NPU now.")
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
-        raise RuntimeError("LoRA is not supported on NPU now.")
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
-        raise RuntimeError("LoRA is not supported on NPU now.")
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> Set[int]:
-        raise RuntimeError("LoRA is not supported on NPU now.")
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_adapters()

    def remove_all_prompt_adapters(self):
        raise RuntimeError("PromptAdapter is not supported on NPU now.")
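A minimal sketch of driving the adapter-management methods above end to end through the engine (assuming the v0 `LLMEngine` adapter API; the model name and adapter path are placeholders). Each engine call is eventually routed down to the worker's model runner:

```python
from vllm import LLM
from vllm.lora.request import LoRARequest

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", enable_lora=True)
engine = llm.llm_engine

# add_lora/list_loras/remove_lora are forwarded to the model runner above.
engine.add_lora(LoRARequest("sql_adapter", 1, "/path/to/sql_lora"))
print(engine.list_loras())   # {1}
engine.remove_lora(1)
print(engine.list_loras())   # set()
```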
@@ -1086,6 +1204,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        self.attn_state.begin_forward(model_input)

        assert model_input.attn_metadata is not None
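A sketch of the per-step contract enforced above: before each forward pass, `execute_model` activates exactly the adapters referenced by the batch. The values are illustrative; the mapping fields mirror the `LoRAMapping` construction in the input builder:

```python
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest

lora_requests = {LoRARequest("sql_adapter", 1, "/path/to/sql_lora")}
lora_mapping = LoRAMapping(index_mapping=(1, 1, 1, 1),  # one id per token
                           prompt_mapping=(1,),         # one sampled token
                           is_prefill=True)
# runner.set_active_loras(lora_requests, lora_mapping)  # done in execute_model
```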