Add Custom Kernels For LoRA Performance (#2325)
### What this PR does / why we need it?
Add two custom operators (sgmv_shrink and sgmv_expand) to address the
performance issues of LoRA. In addition, enable the LoRA operators to
run in ACL graph mode, which improves model inference
performance.
### Does this PR introduce _any_ user-facing change?
No user-facing change.
### How was this patch tested?
Tested with the Qwen2.5 7B model on vllm-ascend version v0.9.2.rc1:
in ACL graph mode, TTFT, TPOT and throughput all improved by
about 100%.
Signed-off-by: liuchn <909698896@qq.com>
- vLLM version: v0.10.0
- vLLM main:
1f83e7d849
---------
Signed-off-by: liuchn <909698896@qq.com>
Co-authored-by: liuchn <909698896@qq.com>
This commit is contained in:
@@ -52,9 +52,14 @@ def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
return torch.ops._C.bgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
output_tensor,
|
||||
slice_offset,
|
||||
slice_size
|
||||
)
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
@@ -69,11 +74,8 @@ def sgmv_shrink(
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
|
||||
scaling)
|
||||
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor,
|
||||
seq_len_tensor, output_tensor, scaling)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
@@ -86,11 +88,15 @@ def sgmv_expand(inputs: torch.Tensor,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
add_inputs)
|
||||
return torch.ops._C.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
output_tensor,
|
||||
0,
|
||||
output_tensor.size(1),
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
@@ -105,8 +111,12 @@ def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
slice_offset, slice_size, add_inputs)
|
||||
return torch.ops._C.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
output_tensor,
|
||||
slice_offset,
|
||||
slice_size
|
||||
)
|
||||
|
||||
@@ -22,8 +22,8 @@ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
# inherit this class
|
||||
class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperNPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
PunicaWrapperNPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||
"""
|
||||
|
||||
@@ -130,7 +130,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||
computation, which is suitable for the
|
||||
GEMM of lora'b.
|
||||
"""
|
||||
@@ -166,11 +166,11 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
|
||||
Args:
|
||||
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
@@ -195,19 +195,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||
lora_bias_stacked[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||
lora_bias_stacked[i]
|
||||
offset += slice
|
||||
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||
bias's weight
|
||||
output_slices (Tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
@@ -266,7 +266,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
|
||||
Reference in New Issue
Block a user