Add Custom Kernels For LoRA Performance (#2325)
### What this PR does / why we need it?
Add two custom operators (sgmv_shrink and sgmv_expand) to address the
performance issues of LoRA. Meanwhile, enable the graph mode for LoRA
operators to enter ACL, so as to improve the model inference
performance.
### Does this PR introduce _any_ user-facing change?
no user-facing change
### How was this patch tested?
Based on the actual test of the QWen2.5 7B model using vllm-ascend
version v0.9.2.rc1, in acl graph mode, the TTFT, TPOT and throughput
have increased by about 100%.
Signed-off-by: liuchn <909698896@qq.com>
- vLLM version: v0.10.0
- vLLM main:
1f83e7d849
---------
Signed-off-by: liuchn <909698896@qq.com>
Co-authored-by: liuchn <909698896@qq.com>
This commit is contained in:
@@ -52,9 +52,14 @@ def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
return torch.ops._C.bgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
output_tensor,
|
||||
slice_offset,
|
||||
slice_size
|
||||
)
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
@@ -69,11 +74,8 @@ def sgmv_shrink(
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
|
||||
scaling)
|
||||
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor,
|
||||
seq_len_tensor, output_tensor, scaling)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
@@ -86,11 +88,15 @@ def sgmv_expand(inputs: torch.Tensor,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
add_inputs)
|
||||
return torch.ops._C.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
output_tensor,
|
||||
0,
|
||||
output_tensor.size(1),
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
@@ -105,8 +111,12 @@ def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
slice_offset, slice_size, add_inputs)
|
||||
return torch.ops._C.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
output_tensor,
|
||||
slice_offset,
|
||||
slice_size
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user