[BugFix] Fix the error when using Ascend custom operators with rank=128 (#5394)

### What this PR does / why we need it?
The customized Ascend operators sgmv_expand and sgmv_shrink apply only
to scenarios where the rank is 8, 16, 32, or 64. When rank >= 128, the
operator index is out of range, causing the model to report an error.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Depends on commit https://github.com/vllm-project/vllm/pull/31408
- vLLM version: release/v0.13.0
- vLLM main:
254f6b9867

---------

Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
This commit is contained in:
ZT-AIA
2026-01-09 15:57:43 +08:00
committed by GitHub
parent d36ca88cf4
commit e11ff8e535
4 changed files with 35 additions and 23 deletions

View File

@@ -105,7 +105,7 @@ jobs:
# xgrammar has parameter mismatching bug, please follows: https://github.com/vllm-project/vllm-ascend/issues/5524
# pytest -sv --durations=0 tests/e2e/singlecard/test_guided_decoding.py
# torch 2.8 doesn't work with lora, fix me
#pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py
pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py
pytest -sv --durations=0 tests/e2e/singlecard/test_models.py
pytest -sv --durations=0 tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
pytest -sv --durations=0 tests/e2e/singlecard/test_profile_execute_duration.py
@@ -216,7 +216,7 @@ jobs:
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_external_launcher.py
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_full_graph_mode.py
# torch 2.8 doesn't work with lora, fix me
#pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
# To avoid oom, we need to run the test in a single process.

View File

@@ -18,6 +18,7 @@ def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files):
tensor_parallel_size=2,
cudagraph_capture_sizes=[1, 2, 4, 8],
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
) as vllm_model:
output = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)

View File

@@ -53,6 +53,7 @@ def test_ilama_lora(ilama_lora_files):
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
max_num_seqs=16,
enforce_eager=True,
) as vllm_model:
output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1)

View File

@@ -3,21 +3,10 @@
from typing import Callable, Optional, Tuple, Union
import torch
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
if get_ascend_device_type() == AscendDeviceType._310P:
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
else:
from vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm_ascend.lora.utils import refresh_all_lora_classes
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
# The platforms that are compatible with the PyTorch-native implementation can
@@ -34,6 +23,27 @@ class PunicaWrapperNPU(PunicaWrapperBase):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
refresh_all_lora_classes()
self.lora_config = kwargs.get("lora_config")
if get_ascend_device_type() == AscendDeviceType._310P or (
self.lora_config is not None
and self.lora_config.max_lora_rank >= 128):
from vllm.lora.ops.torch_ops import (bgmv_expand,
bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice,
sgmv_shrink)
else:
from vllm_ascend.lora.lora_ops import (bgmv_expand,
bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice,
sgmv_shrink)
self.bgmv_expand = bgmv_expand
self.bgmv_expand_slice = bgmv_expand_slice
self.bgmv_shrink = bgmv_shrink
self.sgmv_expand = sgmv_expand
self.sgmv_expand_slice = sgmv_expand_slice
self.sgmv_shrink = sgmv_shrink
def _shrink_prefill(
self,
@@ -45,7 +55,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
self.sgmv_shrink(
x,
w_t_all,
y,
@@ -60,7 +70,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
self.bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
@@ -72,7 +82,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
self.sgmv_expand(
x,
w_t_all,
y,
@@ -87,7 +97,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
w_t_all: torch.Tensor,
add_inputs: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
self.bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
def _expand_slice_prefill(
self,
@@ -101,7 +111,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
self.sgmv_expand_slice(
x,
w_t_all,
y,
@@ -120,8 +130,8 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_slice_size: int,
add_inputs: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_inputs)
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices,
y_offset, y_slice_size, add_inputs)
def _apply_expand(
self,
@@ -346,7 +356,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
indices = self.sampler_indices
bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
self.bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
self.bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
y = y.view_as(y_org)