### What this PR does / why we need it?
Add two custom kernels (bgmv_shrink and bgmv_expand) to improve LoRA performance.
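Roughly, for each token bgmv_shrink projects the hidden state down through the LoRA A matrix selected by that token's adapter index, and bgmv_expand projects the low-rank result back up through the B matrix. A minimal PyTorch reference of the intended math (a sketch only; the tensor layouts and the helper name `lora_bgmv_reference` are illustrative assumptions, not the kernels' actual interface):

```python
import torch

def lora_bgmv_reference(x: torch.Tensor, lora_a: torch.Tensor,
                        lora_b: torch.Tensor, indices: torch.Tensor,
                        scaling: float) -> torch.Tensor:
    # Assumed layouts: x [B, H] hidden states; lora_a [N, r, H];
    # lora_b [N, H, r]; indices [B] adapter id per token.
    a = lora_a[indices]                      # [B, r, H] per-token A matrices
    b = lora_b[indices]                      # [B, H, r] per-token B matrices
    shrunk = torch.bmm(a, x.unsqueeze(-1))   # shrink: [B, r, 1]
    return torch.bmm(b, shrunk).squeeze(-1) * scaling  # expand: [B, H]
```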
### Does this PR introduce _any_ user-facing change?
No user-facing change.
### How was this patch tested?
We add unit test files for the custom AscendC kernels. See
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_shrink.py and
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_expand.py.
In testing with the Qwen2.5 7B model on vllm-ascend
v0.9.2.rc1, TTFT, TPOT, and throughput each improved by
about 70%.
- vLLM version: v0.9.2
- vLLM main:
40d86ee412
---------
Signed-off-by: taoxudonghaha <justsheldon@163.com>
import torch

from vllm_ascend.utils import enable_custom_op

enable_custom_op()

DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def bgmv_shrink_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
                         scaling: float) -> torch.Tensor:
    # Reference implementation: gather each token's LoRA A matrix by its
    # adapter index, project x down to the low-rank space, scale, and
    # accumulate into y.
    W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
    z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze(1)
    y[:, :] += z * scaling
    return y


@torch.inference_mode()
def test_bgmv_shrink() -> None:
    B = 1
    x = torch.randn([B, 128], dtype=torch.float16)
    w = torch.randn([64, 16, 128], dtype=torch.float16)
    indices = torch.zeros([B], dtype=torch.int64)
    y = torch.zeros([B, 16])

    x_npu = x.npu()
    w_npu = w.npu()
    indices_npu = indices.npu()
    y_npu = y.npu()

    y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
    torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)

    # Compare the custom-kernel result against the CPU reference.
    torch.testing.assert_close(y_npu.cpu(),
                               y,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
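The companion test in tests/e2e/singlecard/ops/test_bgmv_expand.py follows the same pattern. A minimal CPU reference for the expand direction, mirroring bgmv_shrink_cpu_impl above (a sketch under assumed tensor layouts; the real kernel's interface may take extra arguments such as output slice offsets, which are omitted here):

```python
import torch

def bgmv_expand_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor,
                         y: torch.Tensor) -> torch.Tensor:
    # Assumed layouts: x [B, r] low-rank inputs; w [num_loras, H, r]
    # LoRA B matrices; indices [B] adapter id per token; y [B, H] output.
    W = w[indices, :, :].to(torch.float32)                      # [B, H, r]
    z = torch.bmm(W, x.unsqueeze(-1).to(torch.float32))         # [B, H, 1]
    y[:, :] += z.squeeze(-1)                                    # accumulate
    return y
```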