Add Custom Kernels For LoRA Performance (#2325)

### What this PR does / why we need it?
Add two custom operators, `sgmv_shrink` and `sgmv_expand`, to address
LoRA performance issues. Also enable graph mode for the LoRA operators
so that they can run in ACL graph mode, which improves model inference
performance.
### Does this PR introduce _any_ user-facing change?
No user-facing change.
### How was this patch tested?
Tested with the Qwen2.5 7B model on vllm-ascend v0.9.2.rc1. In ACL
graph mode, TTFT, TPOT and throughput each improved by about 100%.

Signed-off-by: liuchn <909698896@qq.com>

- vLLM version: v0.10.0
- vLLM main: 1f83e7d849

---------

Signed-off-by: liuchn <909698896@qq.com>
Co-authored-by: liuchn <909698896@qq.com>
Author: liuchenbing
Date: 2025-08-19 09:09:11 +08:00 (committed by GitHub)
Parent: 8fb50a4248
Commit: 3648d18e67
8 changed files with 847 additions and 29 deletions


@@ -52,9 +52,14 @@ def bgmv_expand_slice(inputs: torch.Tensor,
                       slice_offset: int,
                       slice_size: int,
                       add_inputs: bool = True):
-    return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
-                                    lora_indices_tensor, output_tensor,
-                                    slice_offset, slice_size)
+    return torch.ops._C.bgmv_expand(
+        inputs,
+        lora_b_weights,
+        lora_indices_tensor,
+        output_tensor,
+        slice_offset,
+        slice_size
+    )
 def sgmv_shrink(
@@ -69,11 +74,8 @@ def sgmv_shrink(
     token_nums: int,
     scaling: float,
 ):
-    exploded_indices = torch.repeat_interleave(lora_indices_tensor,
-                                               seq_len_tensor)
-    bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
-                scaling)
+    return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor,
+                                    seq_len_tensor, output_tensor, scaling)
 def sgmv_expand(inputs: torch.Tensor,
@@ -86,11 +88,15 @@ def sgmv_expand(inputs: torch.Tensor,
                 max_seq_length: int,
                 token_nums: int,
                 add_inputs: bool = False):
-    exploded_indices = torch.repeat_interleave(lora_indices_tensor,
-                                               seq_len_tensor)
-    bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
-                add_inputs)
+    return torch.ops._C.sgmv_expand(
+        inputs,
+        lora_b_weights,
+        lora_indices_tensor,
+        seq_len_tensor,
+        output_tensor,
+        0,
+        output_tensor.size(1),
+    )
 def sgmv_expand_slice(inputs: torch.Tensor,
@@ -105,8 +111,12 @@ def sgmv_expand_slice(inputs: torch.Tensor,
                       slice_offset: int,
                       slice_size: int,
                       add_inputs: bool = False):
-    exploded_indices = torch.repeat_interleave(lora_indices_tensor,
-                                               seq_len_tensor)
-    bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
-                      slice_offset, slice_size, add_inputs)
+    return torch.ops._C.sgmv_expand(
+        inputs,
+        lora_b_weights,
+        lora_indices_tensor,
+        seq_len_tensor,
+        output_tensor,
+        slice_offset,
+        slice_size
+    )
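
Reviewer note: the removed code path expanded the per-sequence adapter indices into per-token indices with `torch.repeat_interleave` and then called the `bgmv_*` wrappers, whereas the new kernels take `lora_indices_tensor` and `seq_len_tensor` directly. The pure-Python sketch below illustrates what that expansion does and what a shrink step might compute on top of it. The helper names `explode_indices` and `shrink_reference` are hypothetical, and the weight layout (`[num_loras][rank][hidden]`) and the `scaling * x @ A^T` formula are assumptions for illustration, not taken from the kernel source.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def explode_indices(lora_indices: Sequence[int],
                    seq_lens: Sequence[int]) -> List[int]:
    """Pure-Python stand-in for torch.repeat_interleave(lora_indices, seq_lens).

    Each sequence's adapter index is repeated once per token in that sequence.
    """
    if len(lora_indices) != len(seq_lens):
        raise ValueError("need one adapter index per sequence")
    token_indices: List[int] = []
    for adapter, length in zip(lora_indices, seq_lens):
        token_indices.extend([adapter] * length)
    return token_indices


def shrink_reference(inputs: Matrix,
                     lora_a: List[Matrix],
                     lora_indices: Sequence[int],
                     seq_lens: Sequence[int],
                     scaling: float) -> Matrix:
    """Hypothetical reference for a shrink step (an assumption, not the kernel).

    For every token row x that belongs to adapter a, compute
    scaling * (x @ lora_a[a]^T), with lora_a[a] assumed to be [rank][hidden].
    """
    token_indices = explode_indices(lora_indices, seq_lens)
    if len(token_indices) != len(inputs):
        raise ValueError("sum(seq_lens) must equal the number of token rows")
    output: Matrix = []
    for row, adapter in zip(inputs, token_indices):
        weights = lora_a[adapter]
        output.append([
            scaling * sum(x * w for x, w in zip(row, weight_row))
            for weight_row in weights
        ])
    return output
```

A reference like this can be handy when checking a fused kernel against the old two-step path on small inputs, since both should agree on which adapter each token row uses.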