[Platform][Worker][ModelRunner] Add LoRA & Multi-LoRA support (#521)
### What this PR does / why we need it? According to this RFC [[RFC]: Join the MultiLora and MultiLora Dynammic Serving feature develop #396](https://github.com/vllm-project/vllm-ascend/issues/396) and this [vLLM Ascend Roadmap Q2 2025 #448](https://github.com/vllm-project/vllm-ascend/issues/448), we pull request relavant code to support (1) Multi-LoRA and (2) Multi-LoRA Dynamic Serving. LoRA reference is here: [LoRA reference](https://docs.vllm.ai/en/latest/features/lora.html) ### Does this PR introduce _any_ user-facing change? Following openai HTTP apis will be supported: /v1/load_lora_adapter /v1/unload_lora_adapter ### How was this patch tested? git clone https://github.com/vllm-project/vllm.git cd vllm/examples/offline_inference/ && python3 multilora_inference.py --------- Signed-off-by: paulyu <paulyu0307@gmail.com> Co-authored-by: paulyu <paulyu0307@gmail.com>
This commit is contained in:
346
vllm_ascend/lora/punica_wrapper/punica_npu.py
Normal file
346
vllm_ascend/lora/punica_wrapper/punica_npu.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||||
|
bgmv_shrink, sgmv_expand,
|
||||||
|
sgmv_expand_slice, sgmv_shrink)
|
||||||
|
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||||
|
|
||||||
|
|
||||||
|
# The platforms that are compatible with the PyTorch-native implementation can
|
||||||
|
# inherit this class
|
||||||
|
class PunicaWrapperNPU(PunicaWrapperBase):
|
||||||
|
"""
|
||||||
|
PunicaWrapperNPU is designed to manage and provide metadata for the punica
|
||||||
|
kernel. The main function is to maintain the state information for
|
||||||
|
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_num_batched_tokens: int, max_batches: int,
|
||||||
|
device: Union[torch.device, str], **kwargs):
|
||||||
|
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
|
||||||
|
device)
|
||||||
|
|
||||||
|
def _shrink_prefill(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
):
|
||||||
|
#No LoRA request, so return directly
|
||||||
|
if self.no_lora:
|
||||||
|
return
|
||||||
|
sgmv_shrink(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
*self.prefill_metadata,
|
||||||
|
scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _shrink_decode(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
):
|
||||||
|
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
|
||||||
|
|
||||||
|
def _expand_prefill(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
#No LoRA request, so return directly
|
||||||
|
if self.no_lora:
|
||||||
|
return
|
||||||
|
sgmv_expand(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
*self.prefill_metadata,
|
||||||
|
add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expand_decode(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
|
||||||
|
|
||||||
|
def _expand_slice_prefill(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
y_offset: int,
|
||||||
|
y_slice_size: int,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
#No LoRA request, so return directly
|
||||||
|
if self.no_lora:
|
||||||
|
return
|
||||||
|
sgmv_expand_slice(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
*self.prefill_metadata,
|
||||||
|
y_offset,
|
||||||
|
y_slice_size,
|
||||||
|
add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expand_slice_decode(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
y_offset: int,
|
||||||
|
y_slice_size: int,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
|
||||||
|
y_slice_size, add_inputs)
|
||||||
|
|
||||||
|
def _apply_expand(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
y_offset: int,
|
||||||
|
y_slice_size: int,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||||
|
computation, which is suitable for the
|
||||||
|
GEMM of lora'b.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expand_slice_fun: Callable = (self._expand_slice_prefill
|
||||||
|
if self.is_prefill else
|
||||||
|
self._expand_slice_decode)
|
||||||
|
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
|
||||||
|
|
||||||
|
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor, scale: float):
|
||||||
|
"""
|
||||||
|
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
||||||
|
GEMM of lora'a.
|
||||||
|
When `is_prefill is` true, it indicates that it is currently the
|
||||||
|
prefill stage, and the `_shrink_prefill` function should be called.
|
||||||
|
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||||
|
should be called.
|
||||||
|
"""
|
||||||
|
y_org = y
|
||||||
|
y = y.view(-1, y.shape[-1])
|
||||||
|
shrink_fun: Callable = (self._shrink_prefill
|
||||||
|
if self.is_prefill else self._shrink_decode)
|
||||||
|
shrink_fun(y, x, w_t_all, scale)
|
||||||
|
y = y.view_as(y_org)
|
||||||
|
|
||||||
|
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||||
|
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
scale: float, **kwargs):
|
||||||
|
"""
|
||||||
|
Performs GEMM for multiple slices of lora_a.
|
||||||
|
When `is_prefill is` true, it indicates that it is currently the
|
||||||
|
prefill stage, and the `_shrink_prefill` function should be called.
|
||||||
|
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||||
|
should be called.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
for i in range(len(lora_a_stacked)):
|
||||||
|
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||||
|
x (torch.Tensor): Input tensor
|
||||||
|
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
|
||||||
|
scale (float): Scaling factor for the operation
|
||||||
|
"""
|
||||||
|
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
# TODO fuse these kernels
|
||||||
|
for slice_idx in range(len(lora_a_stacked)):
|
||||||
|
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
|
||||||
|
scale)
|
||||||
|
|
||||||
|
def add_expand(self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||||
|
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||||
|
output_slices: Tuple[int, ...],
|
||||||
|
offset_start: int = 0,
|
||||||
|
add_inputs=True,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
for i in range(len(lora_b_stacked)):
|
||||||
|
slice = output_slices[i]
|
||||||
|
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||||
|
lora_bias_stacked[i]
|
||||||
|
offset += slice
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor.
|
||||||
|
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||||
|
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
||||||
|
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||||
|
bias's weight
|
||||||
|
output_slices (Tuple[int, ...]): Every slice's size
|
||||||
|
add_inputs (bool): Defaults to True.
|
||||||
|
"""
|
||||||
|
y_org = y
|
||||||
|
y = y.view(-1, y.shape[-1])
|
||||||
|
offset_left = offset_start
|
||||||
|
if lora_bias_stacked is not None:
|
||||||
|
self._apply_bias(self.token_lora_indices, y, output_slices,
|
||||||
|
lora_bias_stacked)
|
||||||
|
for slice_idx in range(len(lora_b_stacked)):
|
||||||
|
self._apply_expand(
|
||||||
|
y,
|
||||||
|
x[slice_idx],
|
||||||
|
lora_b_stacked[slice_idx],
|
||||||
|
offset_left,
|
||||||
|
output_slices[slice_idx],
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
offset_left += output_slices[slice_idx]
|
||||||
|
y = y.view_as(y_org)
|
||||||
|
|
||||||
|
def add_lora_embedding(self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_b_stacked: torch.Tensor,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
y += x @ lora_b_stacked
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor.
|
||||||
|
x (torch.Tensor): Input tensor.
|
||||||
|
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||||
|
add_inputs (bool): Default to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Embedding layer only need expand op
|
||||||
|
expand_fun: Callable = (self._expand_prefill
|
||||||
|
if self.is_prefill else self._expand_decode)
|
||||||
|
expand_fun(y, x, lora_b_stacked, add_inputs)
|
||||||
|
|
||||||
|
def add_lora_linear(self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||||
|
scale: float,
|
||||||
|
output_slices: Tuple[int, ...],
|
||||||
|
*,
|
||||||
|
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Applicable to linear-related lora.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
for i in range(len(lora_a_stacked)):
|
||||||
|
y[i] += (
|
||||||
|
x[i].unsqueeze(0)
|
||||||
|
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||||
|
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||||
|
* scale
|
||||||
|
).squeeze(0)+lora_bias_stacked[i]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||||
|
x (torch.Tensor): Input tensor
|
||||||
|
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
|
||||||
|
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
|
||||||
|
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
|
||||||
|
scale (float): Scaling factor.
|
||||||
|
output_slices (Tuple[int, ...]): Every slice's size.
|
||||||
|
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||||
|
if lora_bias_stacked is not None:
|
||||||
|
assert len(lora_bias_stacked) == len(output_slices)
|
||||||
|
y = self._apply_bias(self.token_lora_indices, y, output_slices,
|
||||||
|
lora_bias_stacked)
|
||||||
|
|
||||||
|
if buffer is None:
|
||||||
|
r = lora_b_stacked[0].size(-1)
|
||||||
|
# We set the buffer to be float32 by default, consistent with the
|
||||||
|
# triton op
|
||||||
|
buffer = tuple(
|
||||||
|
torch.zeros(
|
||||||
|
(x.size(0), r), dtype=torch.float32, device=x.device)
|
||||||
|
for _ in range(len(output_slices)))
|
||||||
|
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||||
|
self.add_expand(y,
|
||||||
|
buffer,
|
||||||
|
lora_b_stacked,
|
||||||
|
None,
|
||||||
|
output_slices,
|
||||||
|
add_inputs=True,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def add_lora_logits(self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: torch.Tensor,
|
||||||
|
lora_b_stacked: torch.Tensor,
|
||||||
|
scale,
|
||||||
|
*,
|
||||||
|
buffer: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
buffer = (x @ lora_a_stacked) * scale
|
||||||
|
y += buffer @ lora_b_stacked
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor.
|
||||||
|
x (torch.Tensor): Input tensor.
|
||||||
|
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||||
|
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||||
|
scale (float): Scaling factor.
|
||||||
|
buffer (Optional[torch.Tensor]):Default to None.
|
||||||
|
"""
|
||||||
|
y_org = y
|
||||||
|
y = y.view(-1, y.shape[-1])
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
r = lora_b_stacked.size(-1)
|
||||||
|
if buffer is None:
|
||||||
|
# We set the buffer to be float32 by default, consistent with the
|
||||||
|
# triton op
|
||||||
|
buffer = torch.zeros((x.size(0), r),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=x.device)
|
||||||
|
# LogitsProcessorWithLoRA always using bgmv.
|
||||||
|
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
|
||||||
|
bgmv_expand(buffer,
|
||||||
|
lora_b_stacked,
|
||||||
|
y,
|
||||||
|
self.sampler_indices,
|
||||||
|
add_inputs=True)
|
||||||
|
y = y.view_as(y_org)
|
||||||
@@ -141,6 +141,10 @@ class NPUPlatform(Platform):
|
|||||||
return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
|
return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
|
||||||
return "vllm_ascend.attention.attention.AscendAttentionBackend"
|
return "vllm_ascend.attention.attention.AscendAttentionBackend"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_current_memory_usage(cls,
|
def get_current_memory_usage(cls,
|
||||||
device: Optional[torch.types.Device] = None
|
device: Optional[torch.types.Device] = None
|
||||||
|
|||||||
@@ -38,11 +38,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
|||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.lora.layers import LoRAMapping
|
from vllm.lora.layers import LoRAMapping
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||||
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
|
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
|
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||||
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
||||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||||
MultiModalKwargs, MultiModalPlaceholderMap,
|
MultiModalKwargs, MultiModalPlaceholderMap,
|
||||||
@@ -79,6 +81,8 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
|||||||
token_types: Optional[torch.Tensor] = None
|
token_types: Optional[torch.Tensor] = None
|
||||||
seq_lens: Optional[List[int]] = None
|
seq_lens: Optional[List[int]] = None
|
||||||
query_lens: Optional[List[int]] = None
|
query_lens: Optional[List[int]] = None
|
||||||
|
lora_mapping: Optional["LoRAMapping"] = None
|
||||||
|
lora_requests: Optional[Set[LoRARequest]] = None
|
||||||
attn_metadata: Optional["AttentionMetadata"] = None
|
attn_metadata: Optional["AttentionMetadata"] = None
|
||||||
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
||||||
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
|
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
|
||||||
@@ -93,6 +97,8 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
|||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
"input_positions": self.input_positions,
|
"input_positions": self.input_positions,
|
||||||
|
"lora_requests": self.lora_requests,
|
||||||
|
"lora_mapping": self.lora_mapping,
|
||||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
"virtual_engine": self.virtual_engine,
|
"virtual_engine": self.virtual_engine,
|
||||||
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||||
@@ -139,6 +145,8 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
|
|||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
"input_positions": self.input_positions,
|
"input_positions": self.input_positions,
|
||||||
|
"lora_requests": self.lora_requests,
|
||||||
|
"lora_mapping": self.lora_mapping,
|
||||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
"virtual_engine": self.virtual_engine,
|
"virtual_engine": self.virtual_engine,
|
||||||
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||||
@@ -181,6 +189,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
self.query_lens[0] = 0 # type: ignore
|
self.query_lens[0] = 0 # type: ignore
|
||||||
self.context_lens[0] = 0 # type: ignore
|
self.context_lens[0] = 0 # type: ignore
|
||||||
self.curr_sliding_window_blocks[0] = 0 # type: ignore
|
self.curr_sliding_window_blocks[0] = 0 # type: ignore
|
||||||
|
self.lora_index_mapping.clear() # type: ignore
|
||||||
|
self.lora_prompt_mapping.clear() # type: ignore
|
||||||
|
self.lora_requests.clear() # type: ignore
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -211,6 +222,11 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
# The current sliding window block.
|
# The current sliding window block.
|
||||||
curr_sliding_window_blocks: Optional[List[int]] = None,
|
curr_sliding_window_blocks: Optional[List[int]] = None,
|
||||||
|
|
||||||
|
# LoRA inputs.
|
||||||
|
lora_index_mapping: Optional[List[List[int]]] = None,
|
||||||
|
lora_prompt_mapping: Optional[List[List[int]]] = None,
|
||||||
|
lora_requests: Optional[Set[LoRARequest]] = None,
|
||||||
|
|
||||||
# Multi-modal inputs.
|
# Multi-modal inputs.
|
||||||
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
|
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
|
||||||
multi_modal_placeholder_maps: Optional[Dict[
|
multi_modal_placeholder_maps: Optional[Dict[
|
||||||
@@ -291,6 +307,19 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
for seq_id in range(len(self.seq_ids)):
|
for seq_id in range(len(self.seq_ids)):
|
||||||
self.curr_sliding_window_blocks[seq_id] = 0
|
self.curr_sliding_window_blocks[seq_id] = 0
|
||||||
|
|
||||||
|
if lora_index_mapping:
|
||||||
|
self.lora_index_mapping = lora_index_mapping
|
||||||
|
else:
|
||||||
|
self.lora_index_mapping.clear()
|
||||||
|
if lora_prompt_mapping:
|
||||||
|
self.lora_prompt_mapping = lora_prompt_mapping
|
||||||
|
else:
|
||||||
|
self.lora_prompt_mapping.clear()
|
||||||
|
if lora_requests:
|
||||||
|
self.lora_requests = lora_requests
|
||||||
|
else:
|
||||||
|
self.lora_requests.clear()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.input_tokens = input_tokens or []
|
self.input_tokens = input_tokens or []
|
||||||
self.input_positions = input_positions or []
|
self.input_positions = input_positions or []
|
||||||
@@ -303,6 +332,10 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
self.curr_sliding_window_blocks = \
|
self.curr_sliding_window_blocks = \
|
||||||
curr_sliding_window_blocks or []
|
curr_sliding_window_blocks or []
|
||||||
|
|
||||||
|
self.lora_index_mapping = lora_index_mapping or []
|
||||||
|
self.lora_prompt_mapping = lora_prompt_mapping or []
|
||||||
|
self.lora_requests = lora_requests or set()
|
||||||
|
|
||||||
self.multi_modal_kwargs = multi_modal_kwargs
|
self.multi_modal_kwargs = multi_modal_kwargs
|
||||||
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
|
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
|
||||||
self.prefix_cache_hit = prefix_cache_hit
|
self.prefix_cache_hit = prefix_cache_hit
|
||||||
@@ -325,6 +358,9 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
self.context_lens = [0] * self.n_seqs
|
self.context_lens = [0] * self.n_seqs
|
||||||
self.curr_sliding_window_blocks = [0] * self.n_seqs
|
self.curr_sliding_window_blocks = [0] * self.n_seqs
|
||||||
|
|
||||||
|
self.lora_index_mapping = []
|
||||||
|
self.lora_prompt_mapping = []
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
runner,
|
runner,
|
||||||
finished_requests_ids: Optional[List[str]] = None):
|
finished_requests_ids: Optional[List[str]] = None):
|
||||||
@@ -335,6 +371,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
self._compute_lens,
|
self._compute_lens,
|
||||||
self._compute_for_prefix_cache_hit,
|
self._compute_for_prefix_cache_hit,
|
||||||
self._compute_for_sliding_window,
|
self._compute_for_sliding_window,
|
||||||
|
self._compute_lora_input,
|
||||||
]
|
]
|
||||||
# Compute functions for each sequence group.
|
# Compute functions for each sequence group.
|
||||||
# WARNING: The order of the functions matters!
|
# WARNING: The order of the functions matters!
|
||||||
@@ -348,6 +385,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
self.scheduler_config = self.runner.scheduler_config
|
self.scheduler_config = self.runner.scheduler_config
|
||||||
self.sliding_window = self.runner.sliding_window
|
self.sliding_window = self.runner.sliding_window
|
||||||
self.block_size = self.runner.block_size
|
self.block_size = self.runner.block_size
|
||||||
|
self.enable_lora = self.runner.lora_config is not None
|
||||||
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
|
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
|
||||||
self.finished_requests_ids = finished_requests_ids
|
self.finished_requests_ids = finished_requests_ids
|
||||||
self.decode_only = True
|
self.decode_only = True
|
||||||
@@ -512,6 +550,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
# Attention metadata.
|
# Attention metadata.
|
||||||
attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)
|
attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)
|
||||||
|
|
||||||
|
# LoRA data.
|
||||||
|
lora_requests = set()
|
||||||
|
lora_mapping = None
|
||||||
|
if self.enable_lora:
|
||||||
|
lora_requests = set(r for data in self.inter_data_list
|
||||||
|
for r in data.lora_requests)
|
||||||
|
lora_index_mapping = flatten_2d_lists([
|
||||||
|
flatten_2d_lists(inter_data.lora_index_mapping)
|
||||||
|
for inter_data in self.inter_data_list
|
||||||
|
])
|
||||||
|
lora_prompt_mapping = flatten_2d_lists([
|
||||||
|
flatten_2d_lists(inter_data.lora_prompt_mapping)
|
||||||
|
for inter_data in self.inter_data_list
|
||||||
|
])
|
||||||
|
lora_mapping = LoRAMapping(
|
||||||
|
**dict(index_mapping=lora_index_mapping,
|
||||||
|
prompt_mapping=lora_prompt_mapping,
|
||||||
|
is_prefill=not self.decode_only))
|
||||||
|
|
||||||
# Multi-modal data.
|
# Multi-modal data.
|
||||||
multi_modal_kwargs_list = [
|
multi_modal_kwargs_list = [
|
||||||
data.multi_modal_kwargs for data in self.inter_data_list
|
data.multi_modal_kwargs for data in self.inter_data_list
|
||||||
@@ -525,6 +582,8 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
query_lens=query_lens,
|
query_lens=query_lens,
|
||||||
|
lora_mapping=lora_mapping,
|
||||||
|
lora_requests=lora_requests,
|
||||||
multi_modal_kwargs=multi_modal_kwargs,
|
multi_modal_kwargs=multi_modal_kwargs,
|
||||||
request_ids_to_seq_ids=request_ids_to_seq_ids,
|
request_ids_to_seq_ids=request_ids_to_seq_ids,
|
||||||
finished_requests_ids=self.finished_requests_ids)
|
finished_requests_ids=self.finished_requests_ids)
|
||||||
@@ -663,6 +722,25 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
|||||||
seq_idx] = curr_sliding_window_block
|
seq_idx] = curr_sliding_window_block
|
||||||
inter_data.seq_lens[seq_idx] = sliding_seq_len
|
inter_data.seq_lens[seq_idx] = sliding_seq_len
|
||||||
|
|
||||||
|
def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
|
||||||
|
seq_idx: int,
|
||||||
|
seq_group_metadata: SequenceGroupMetadata):
|
||||||
|
"""If LoRA is enabled, compute LoRA index and prompt mapping."""
|
||||||
|
if not self.enable_lora:
|
||||||
|
return
|
||||||
|
lora_id = seq_group_metadata.lora_int_id
|
||||||
|
if lora_id > 0:
|
||||||
|
inter_data.lora_requests.add(seq_group_metadata.lora_request)
|
||||||
|
query_len = inter_data.query_lens[seq_idx]
|
||||||
|
inter_data.lora_index_mapping.append([lora_id] * query_len)
|
||||||
|
sampling_params = seq_group_metadata.sampling_params
|
||||||
|
if sampling_params and sampling_params.prompt_logprobs is not None:
|
||||||
|
inter_data.lora_prompt_mapping.append([lora_id] * query_len)
|
||||||
|
elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
|
||||||
|
inter_data.lora_prompt_mapping.append([lora_id])
|
||||||
|
else:
|
||||||
|
inter_data.lora_prompt_mapping.append([])
|
||||||
|
|
||||||
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
|
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
|
||||||
seq_group_metadata: SequenceGroupMetadata):
|
seq_group_metadata: SequenceGroupMetadata):
|
||||||
"""If multi-modal data is given, add it to the input."""
|
"""If multi-modal data is given, add it to the input."""
|
||||||
@@ -789,6 +867,8 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
|||||||
|
|
||||||
# Lazy initialization
|
# Lazy initialization
|
||||||
self.model: nn.Module # Set after load_model
|
self.model: nn.Module # Set after load_model
|
||||||
|
# Set after load_model.
|
||||||
|
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
|
||||||
|
|
||||||
set_cpu_offload_max_bytes(
|
set_cpu_offload_max_bytes(
|
||||||
int(self.cache_config.cpu_offload_gb * 1024**3))
|
int(self.cache_config.cpu_offload_gb * 1024**3))
|
||||||
@@ -818,6 +898,32 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
|||||||
logger.info("Loading model weights took %.4f GB",
|
logger.info("Loading model weights took %.4f GB",
|
||||||
self.model_memory_usage / float(2**30))
|
self.model_memory_usage / float(2**30))
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
assert supports_lora(
|
||||||
|
self.model
|
||||||
|
), f"{self.model.__class__.__name__} does not support LoRA yet."
|
||||||
|
if supports_multimodal(self.model):
|
||||||
|
logger.warning("Regarding multimodal models, vLLM currently "
|
||||||
|
"only supports adding LoRA to language model.")
|
||||||
|
# It's necessary to distinguish between the max_position_embeddings
|
||||||
|
# of VLMs and LLMs.
|
||||||
|
if hasattr(self.model.config, "max_position_embeddings"):
|
||||||
|
max_pos_embeddings = self.model.config.max_position_embeddings
|
||||||
|
else:
|
||||||
|
max_pos_embeddings = (
|
||||||
|
self.model.config.text_config.max_position_embeddings)
|
||||||
|
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||||
|
self.scheduler_config.max_num_seqs,
|
||||||
|
self.scheduler_config.max_num_batched_tokens,
|
||||||
|
self.vocab_size,
|
||||||
|
self.lora_config,
|
||||||
|
self.device,
|
||||||
|
self.model.embedding_modules,
|
||||||
|
self.model.embedding_padding_modules,
|
||||||
|
max_position_embeddings=max_pos_embeddings,
|
||||||
|
)
|
||||||
|
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||||
|
|
||||||
def save_sharded_state(
|
def save_sharded_state(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
@@ -967,23 +1073,35 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def remove_all_loras(self):
|
def remove_all_loras(self):
|
||||||
raise RuntimeError("LoRA is not supported on NPU now.")
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
self.lora_manager.remove_all_adapters()
|
||||||
|
|
||||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||||
lora_mapping: LoRAMapping) -> None:
|
lora_mapping: LoRAMapping) -> None:
|
||||||
raise RuntimeError("LoRA is not supported on NPU now.")
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
raise RuntimeError("LoRA is not supported on NPU now.")
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.add_adapter(lora_request)
|
||||||
|
|
||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise RuntimeError("LoRA is not supported on NPU now.")
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.remove_adapter(lora_id)
|
||||||
|
|
||||||
def pin_lora(self, lora_id: int) -> bool:
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
raise RuntimeError("LoRA is not supported on NPU now.")
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.pin_adapter(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
raise RuntimeError("LoRA is not supported on NPU now.")
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.list_adapters()
|
||||||
|
|
||||||
def remove_all_prompt_adapters(self):
|
def remove_all_prompt_adapters(self):
|
||||||
raise RuntimeError("PromptAdapter is not supported on NPU now.")
|
raise RuntimeError("PromptAdapter is not supported on NPU now.")
|
||||||
@@ -1086,6 +1204,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
|||||||
if num_steps > 1:
|
if num_steps > 1:
|
||||||
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
assert model_input.lora_requests is not None
|
||||||
|
assert model_input.lora_mapping is not None
|
||||||
|
self.set_active_loras(model_input.lora_requests,
|
||||||
|
model_input.lora_mapping)
|
||||||
|
|
||||||
self.attn_state.begin_forward(model_input)
|
self.attn_state.begin_forward(model_input)
|
||||||
|
|
||||||
assert model_input.attn_metadata is not None
|
assert model_input.attn_metadata is not None
|
||||||
|
|||||||
@@ -404,20 +404,16 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
raise NotImplementedError(
|
return self.model_runner.add_lora(lora_request)
|
||||||
"LoRA is not implemented for NPU backend currently.")
|
|
||||||
|
|
||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError(
|
return self.model_runner.remove_lora(lora_id)
|
||||||
"LoRA is not implemented for NPU backend currently.")
|
|
||||||
|
|
||||||
def pin_lora(self, lora_id: int) -> bool:
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError(
|
return self.model_runner.pin_lora(lora_id)
|
||||||
"LoRA is not implemented for NPU backend currently.")
|
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
raise NotImplementedError(
|
return self.model_runner.list_loras()
|
||||||
"LoRA is not implemented for NPU backend currently.")
|
|
||||||
|
|
||||||
def add_prompt_adapter(
|
def add_prompt_adapter(
|
||||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user