xc-llm-ascend/vllm_ascend/lora/punica_npu.py
yupeng 830f39dd70 [Bugfix][LoRA] Fix the issue when enabling LoRA + TP + fully_sharded_loras (#6650)
### What this PR does / why we need it?
Fixes issue #6143.

### Does this PR introduce _any_ user-facing change?
Allows starting the server with `--enable-lora`, `--fully-sharded-loras`,
and `--tensor_parallel_size 2` together, for example:

### How was this patch tested?
`pytest -sv tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py`
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
2026-03-11 15:43:15 +08:00


# SPDX-License-Identifier: Apache-2.0
from collections.abc import Callable

import torch
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase

from vllm_ascend.lora.utils import refresh_all_lora_classes
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperNPU(PunicaWrapperBase):
    """
    PunicaWrapperNPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int, device: torch.device | str, **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
        refresh_all_lora_classes()
        self.lora_config = kwargs.get("lora_config")
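        # Select the LoRA ops: fall back to the PyTorch-native reference ops on
        # 310P devices or when the configured max LoRA rank is >= 128; otherwise
        # use the NPU-optimized ops from vllm_ascend.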
        if get_ascend_device_type() == AscendDeviceType._310P or (
            self.lora_config is not None and self.lora_config.max_lora_rank >= 128
        ):
            from vllm.lora.ops.torch_ops import (
                bgmv_expand,
                bgmv_expand_slice,
                bgmv_shrink,
                sgmv_expand,
                sgmv_expand_slice,
                sgmv_shrink,
            )
        else:
            from vllm_ascend.lora.lora_ops import (
                bgmv_expand,
                bgmv_expand_slice,
                bgmv_shrink,
                sgmv_expand,
                sgmv_expand_slice,
                sgmv_shrink,
            )
        self.bgmv_expand = bgmv_expand
        self.bgmv_expand_slice = bgmv_expand_slice
        self.bgmv_shrink = bgmv_shrink
        self.sgmv_expand = sgmv_expand
        self.sgmv_expand_slice = sgmv_expand_slice
        self.sgmv_shrink = sgmv_shrink
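
    # The *_prefill helpers below dispatch to the sgmv (segmented) ops driven by
    # self.prefill_metadata, while the *_decode helpers dispatch to the
    # per-token bgmv ops driven by self.token_lora_indices.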
    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        self.sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        self.bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        self.sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_inputs,
        )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        self.bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        self.sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs)
    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool = True,
    ):
        """
        Perform the `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`
        computation, which is suitable for the GEMM of lora_b.
        """
        expand_slice_fun: Callable = self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = self._shrink_prefill if self.is_prefill else self._shrink_decode
        shrink_fun(y, x, w_t_all, scale)
        y = y.view_as(y_org)
    def add_shrink(
        self,
        y: tuple[torch.Tensor, ...] | torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ):
        """
        Performs GEMM for multiple slices of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        x = x.view(-1, x.shape[-1])
        # TODO fuse these kernels
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
    def add_expand(
        self,
        y: torch.Tensor,
        x: tuple[torch.Tensor, ...] | torch.Tensor,
        lora_b_stacked: tuple[torch.Tensor, ...],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
        Performs GEMM for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights
            output_slices (Tuple[int, ...]): Every slice's size
            offset_start (int): The starting position of y, defaults to 0
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_left,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)
    def add_lora_embedding(
        self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs
    ) -> None:
        """
        Applies LoRA specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Defaults to True.
        """
        # The embedding layer only needs the expand op
        expand_fun: Callable = self._expand_prefill if self.is_prefill else self._expand_decode
        x = x.to(torch.float32)
        expand_fun(y, x, lora_b_stacked, add_inputs)
    def add_lora_linear(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        scale: float,
        output_slices: tuple[int, ...],
        *,
        buffer: tuple[torch.Tensor, ...] | None = None,
        **kwargs,
    ) -> None:
        """
        Applicable to linear-related LoRA.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0)

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = tuple(
                torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
            )
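        # Two-stage LoRA computation: shrink x through lora_a into the float32
        # buffers, then expand the buffers through lora_b and accumulate into y.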
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs)
    def add_lora_logits(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        scale,
        *,
        buffer: torch.Tensor | None = None,
        **kwargs,
    ) -> None:
        """
        Applies LoRA specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
        indices = self.sampler_indices
        self.bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
        self.bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
        y = y.view_as(y_org)