From 697908f5cd7c65a3a917ec1a962b0886efc98c7e Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Thu, 17 Apr 2025 16:48:46 +0800 Subject: [PATCH] [Platform][Worker][ModelRunner] Add LoRA & Multi-LoRA support (#521) ### What this PR does / why we need it? According to this RFC [[RFC]: Join the MultiLora and MultiLora Dynammic Serving feature develop #396](https://github.com/vllm-project/vllm-ascend/issues/396) and this [vLLM Ascend Roadmap Q2 2025 #448](https://github.com/vllm-project/vllm-ascend/issues/448), we pull request relavant code to support (1) Multi-LoRA and (2) Multi-LoRA Dynamic Serving. LoRA reference is here: [LoRA reference](https://docs.vllm.ai/en/latest/features/lora.html) ### Does this PR introduce _any_ user-facing change? Following openai HTTP apis will be supported: /v1/load_lora_adapter /v1/unload_lora_adapter ### How was this patch tested? git clone https://github.com/vllm-project/vllm.git cd vllm/examples/offline_inference/ && python3 multilora_inference.py --------- Signed-off-by: paulyu Co-authored-by: paulyu --- vllm_ascend/lora/punica_wrapper/punica_npu.py | 346 ++++++++++++++++++ vllm_ascend/platform.py | 4 + vllm_ascend/worker/model_runner.py | 136 ++++++- vllm_ascend/worker/worker.py | 12 +- 4 files changed, 484 insertions(+), 14 deletions(-) create mode 100644 vllm_ascend/lora/punica_wrapper/punica_npu.py diff --git a/vllm_ascend/lora/punica_wrapper/punica_npu.py b/vllm_ascend/lora/punica_wrapper/punica_npu.py new file mode 100644 index 0000000..339ed36 --- /dev/null +++ b/vllm_ascend/lora/punica_wrapper/punica_npu.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Optional, Tuple, Union + +import torch +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperNPU(PunicaWrapperBase): + """ + PunicaWrapperNPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f48ee6e..c54bdfe 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -141,6 +141,10 @@ class NPUPlatform(Platform): return "vllm_ascend.attention.attention.AscendMLAAttentionBackend" return "vllm_ascend.attention.attention.AscendAttentionBackend" + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU" + @classmethod def get_current_memory_usage(cls, device: Optional[torch.types.Device] = None diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index 702b9f3..c240c80 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -38,11 +38,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap, @@ -79,6 +81,8 @@ class ModelInputForNPU(ModelRunnerInputBase): token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None @@ -93,6 +97,8 @@ class ModelInputForNPU(ModelRunnerInputBase): tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, @@ -139,6 +145,8 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU): tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, @@ -181,6 +189,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): self.query_lens[0] = 0 # type: ignore self.context_lens[0] = 0 # type: ignore self.curr_sliding_window_blocks[0] = 0 # type: ignore + self.lora_index_mapping.clear() # type: ignore + self.lora_prompt_mapping.clear() # type: ignore + self.lora_requests.clear() # type: ignore def __init__( self, @@ -211,6 +222,11 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): # The current sliding window block. curr_sliding_window_blocks: Optional[List[int]] = None, + # LoRA inputs. + lora_index_mapping: Optional[List[List[int]]] = None, + lora_prompt_mapping: Optional[List[List[int]]] = None, + lora_requests: Optional[Set[LoRARequest]] = None, + # Multi-modal inputs. multi_modal_kwargs: Optional[MultiModalKwargs] = None, multi_modal_placeholder_maps: Optional[Dict[ @@ -291,6 +307,19 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): for seq_id in range(len(self.seq_ids)): self.curr_sliding_window_blocks[seq_id] = 0 + if lora_index_mapping: + self.lora_index_mapping = lora_index_mapping + else: + self.lora_index_mapping.clear() + if lora_prompt_mapping: + self.lora_prompt_mapping = lora_prompt_mapping + else: + self.lora_prompt_mapping.clear() + if lora_requests: + self.lora_requests = lora_requests + else: + self.lora_requests.clear() + else: self.input_tokens = input_tokens or [] self.input_positions = input_positions or [] @@ -303,6 +332,10 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): self.curr_sliding_window_blocks = \ curr_sliding_window_blocks or [] + self.lora_index_mapping = lora_index_mapping or [] + self.lora_prompt_mapping = lora_prompt_mapping or [] + self.lora_requests = lora_requests or set() + self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit @@ -325,6 +358,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): self.context_lens = [0] * self.n_seqs self.curr_sliding_window_blocks = [0] * self.n_seqs + self.lora_index_mapping = [] + self.lora_prompt_mapping = [] + def __init__(self, runner, finished_requests_ids: Optional[List[str]] = None): @@ -335,6 +371,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): self._compute_lens, self._compute_for_prefix_cache_hit, self._compute_for_sliding_window, + self._compute_lora_input, ] # Compute functions for each sequence group. # WARNING: The order of the functions matters! @@ -348,6 +385,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): self.scheduler_config = self.runner.scheduler_config self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size + self.enable_lora = self.runner.lora_config is not None self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.finished_requests_ids = finished_requests_ids self.decode_only = True @@ -512,6 +550,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): # Attention metadata. attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens) + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(r for data in self.inter_data_list + for r in data.lora_requests) + lora_index_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_index_mapping) + for inter_data in self.inter_data_list + ]) + lora_prompt_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_prompt_mapping) + for inter_data in self.inter_data_list + ]) + lora_mapping = LoRAMapping( + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only)) + # Multi-modal data. multi_modal_kwargs_list = [ data.multi_modal_kwargs for data in self.inter_data_list @@ -525,6 +582,8 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, request_ids_to_seq_ids=request_ids_to_seq_ids, finished_requests_ids=self.finished_requests_ids) @@ -663,6 +722,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): seq_idx] = curr_sliding_window_block inter_data.seq_lens[seq_idx] = sliding_seq_len + def _compute_lora_input(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """If LoRA is enabled, compute LoRA index and prompt mapping.""" + if not self.enable_lora: + return + lora_id = seq_group_metadata.lora_int_id + if lora_id > 0: + inter_data.lora_requests.add(seq_group_metadata.lora_request) + query_len = inter_data.query_lens[seq_idx] + inter_data.lora_index_mapping.append([lora_id] * query_len) + sampling_params = seq_group_metadata.sampling_params + if sampling_params and sampling_params.prompt_logprobs is not None: + inter_data.lora_prompt_mapping.append([lora_id] * query_len) + elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample: + inter_data.lora_prompt_mapping.append([lora_id]) + else: + inter_data.lora_prompt_mapping.append([]) + def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata): """If multi-modal data is given, add it to the input.""" @@ -789,6 +867,8 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): # Lazy initialization self.model: nn.Module # Set after load_model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -818,6 +898,32 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + def save_sharded_state( self, path: str, @@ -967,23 +1073,35 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): return def remove_all_loras(self): - raise RuntimeError("LoRA is not supported on NPU now.") + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: - raise RuntimeError("LoRA is not supported on NPU now.") + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: - raise RuntimeError("LoRA is not supported on NPU now.") + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise RuntimeError("LoRA is not supported on NPU now.") + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: - raise RuntimeError("LoRA is not supported on NPU now.") + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> Set[int]: - raise RuntimeError("LoRA is not supported on NPU now.") + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() def remove_all_prompt_adapters(self): raise RuntimeError("PromptAdapter is not supported on NPU now.") @@ -1086,6 +1204,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + self.attn_state.begin_forward(model_input) assert model_input.attn_metadata is not None diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 39d6fdc..3dc1a88 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -404,20 +404,16 @@ class NPUWorker(LocalOrDistributedWorkerBase): return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError( - "LoRA is not implemented for NPU backend currently.") + return self.model_runner.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRA is not implemented for NPU backend currently.") + return self.model_runner.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRA is not implemented for NPU backend currently.") + return self.model_runner.pin_lora(lora_id) def list_loras(self) -> Set[int]: - raise NotImplementedError( - "LoRA is not implemented for NPU backend currently.") + return self.model_runner.list_loras() def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: