Add Custom Kernels For LoRA Performance (#2325)

### What this PR does / why we need it?
Add two custom operators (sgmv_shrink and sgmv_expand) to address the
performance issues of LoRA. Meanwhile, enable the graph mode for LoRA
operators to enter ACL, so as to improve the model inference
performance.
### Does this PR introduce _any_ user-facing change?
      no user-facing change
### How was this patch tested?
Based on the actual test of the QWen2.5 7B model using vllm-ascend
version v0.9.2.rc1, in acl graph mode, the TTFT, TPOT and throughput
have increased by about 100%.

Signed-off-by: liuchn <909698896@qq.com>

- vLLM version: v0.10.0
- vLLM main:
1f83e7d849

---------

Signed-off-by: liuchn <909698896@qq.com>
Co-authored-by: liuchn <909698896@qq.com>
This commit is contained in:
liuchenbing
2025-08-19 09:09:11 +08:00
committed by GitHub
parent 8fb50a4248
commit 3648d18e67
8 changed files with 847 additions and 29 deletions

View File

@@ -52,9 +52,14 @@ def bgmv_expand_slice(inputs: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, output_tensor,
slice_offset, slice_size)
return torch.ops._C.bgmv_expand(
inputs,
lora_b_weights,
lora_indices_tensor,
output_tensor,
slice_offset,
slice_size
)
def sgmv_shrink(
@@ -69,11 +74,8 @@ def sgmv_shrink(
token_nums: int,
scaling: float,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
scaling)
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor,
seq_len_tensor, output_tensor, scaling)
def sgmv_expand(inputs: torch.Tensor,
@@ -86,11 +88,15 @@ def sgmv_expand(inputs: torch.Tensor,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
add_inputs)
return torch.ops._C.sgmv_expand(
inputs,
lora_b_weights,
lora_indices_tensor,
seq_len_tensor,
output_tensor,
0,
output_tensor.size(1),
)
def sgmv_expand_slice(inputs: torch.Tensor,
@@ -105,8 +111,12 @@ def sgmv_expand_slice(inputs: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = False):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
slice_offset, slice_size, add_inputs)
return torch.ops._C.sgmv_expand(
inputs,
lora_b_weights,
lora_indices_tensor,
seq_len_tensor,
output_tensor,
slice_offset,
slice_size
)

View File

@@ -22,8 +22,8 @@ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
# inherit this class
class PunicaWrapperNPU(PunicaWrapperBase):
"""
PunicaWrapperNPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
PunicaWrapperNPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
@@ -130,7 +130,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
add_inputs: bool = True,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
@@ -166,11 +166,11 @@ class PunicaWrapperNPU(PunicaWrapperBase):
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
@@ -195,19 +195,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
**kwargs) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
@@ -266,7 +266,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):