From e11ff8e535b7bce5614f1eb13b09255df26b8fff Mon Sep 17 00:00:00 2001 From: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:57:43 +0800 Subject: [PATCH] [BugFix]Fix the error when using Ascend custom operators with rank=128 (#5394) ### What this PR does / why we need it? The customized Ascend operators sgmv_expand and sgmv_shrink apply only to the scenario where the rank is 8, 16, 32, or 64. When rank >= 128, the operators are out of range, causing the model to report an error. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Depends on this commit https://github.com/vllm-project/vllm/pull/31408 - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/254f6b986720c92ddf97fbb1a6a6465da8e87e29 --------- Signed-off-by: ZT-AIA <1028681969@qq.com> Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com> --- .github/workflows/_e2e_test.yaml | 4 +- .../multicard/2-cards/test_ilama_lora_tp2.py | 1 + tests/e2e/singlecard/test_ilama_lora.py | 1 + vllm_ascend/lora/punica_npu.py | 52 +++++++++++-------- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 0418b6a0..79c20073 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -105,7 +105,7 @@ jobs: # xgrammar has parameter mismatching bug, please follows: https://github.com/vllm-project/vllm-ascend/issues/5524 # pytest -sv --durations=0 tests/e2e/singlecard/test_guided_decoding.py # torch 2.8 doesn't work with lora, fix me - #pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py + pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py pytest -sv --durations=0 tests/e2e/singlecard/test_models.py pytest -sv --durations=0 tests/e2e/singlecard/test_multistream_overlap_shared_expert.py pytest -sv --durations=0 tests/e2e/singlecard/test_profile_execute_duration.py @@ -216,7 +216,7 @@ jobs: pytest -sv 
--durations=0 tests/e2e/multicard/2-cards/test_external_launcher.py pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_full_graph_mode.py # torch 2.8 doesn't work with lora, fix me - #pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py + pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py # To avoid oom, we need to run the test in a single process. diff --git a/tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py b/tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py index f37978b7..e0174e18 100644 --- a/tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py +++ b/tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py @@ -18,6 +18,7 @@ def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files): tensor_parallel_size=2, cudagraph_capture_sizes=[1, 2, 4, 8], distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, ) as vllm_model: output = do_sample(vllm_model.model, ilama_lora_files, lora_id=2) diff --git a/tests/e2e/singlecard/test_ilama_lora.py b/tests/e2e/singlecard/test_ilama_lora.py index a1f9fd41..d9c4814e 100644 --- a/tests/e2e/singlecard/test_ilama_lora.py +++ b/tests/e2e/singlecard/test_ilama_lora.py @@ -53,6 +53,7 @@ def test_ilama_lora(ilama_lora_files): max_model_len=1024, cudagraph_capture_sizes=[1, 2, 4, 8], max_num_seqs=16, + enforce_eager=True, ) as vllm_model: output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1) diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py index ddc98e32..90c2ef5d 100644 --- a/vllm_ascend/lora/punica_npu.py +++ b/vllm_ascend/lora/punica_npu.py @@ -3,21 +3,10 @@ from typing import Callable, Optional, Tuple, Union import torch - -from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - -if get_ascend_device_type() == AscendDeviceType._310P: - from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) -else: - from 
vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) - from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase from vllm_ascend.lora.utils import refresh_all_lora_classes +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type # The platforms that are compatible with the PyTorch-native implementation can @@ -34,6 +23,27 @@ class PunicaWrapperNPU(PunicaWrapperBase): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) refresh_all_lora_classes() + self.lora_config = kwargs.get("lora_config") + if get_ascend_device_type() == AscendDeviceType._310P or ( + self.lora_config is not None + and self.lora_config.max_lora_rank >= 128): + from vllm.lora.ops.torch_ops import (bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, + sgmv_shrink) + else: + from vllm_ascend.lora.lora_ops import (bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, + sgmv_shrink) + self.bgmv_expand = bgmv_expand + self.bgmv_expand_slice = bgmv_expand_slice + self.bgmv_shrink = bgmv_shrink + self.sgmv_expand = sgmv_expand + self.sgmv_expand_slice = sgmv_expand_slice + self.sgmv_shrink = sgmv_shrink def _shrink_prefill( self, @@ -45,7 +55,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): #No LoRA request, so return directly if self.no_lora: return - sgmv_shrink( + self.sgmv_shrink( x, w_t_all, y, @@ -60,7 +70,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): w_t_all: torch.Tensor, scale: float, ): - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + self.bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def _expand_prefill( self, @@ -72,7 +82,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): #No LoRA request, so return directly if self.no_lora: return - sgmv_expand( + self.sgmv_expand( x, w_t_all, y, @@ -87,7 +97,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): w_t_all: torch.Tensor, add_inputs: bool, ): - 
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + self.bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( self, @@ -101,7 +111,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): #No LoRA request, so return directly if self.no_lora: return - sgmv_expand_slice( + self.sgmv_expand_slice( x, w_t_all, y, @@ -120,8 +130,8 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_slice_size: int, add_inputs: bool, ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) + self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, + y_offset, y_slice_size, add_inputs) def _apply_expand( self, @@ -346,7 +356,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): indices = self.sampler_indices - bgmv_shrink(x, lora_a_stacked, buffer, indices, scale) - bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True) + self.bgmv_shrink(x, lora_a_stacked, buffer, indices, scale) + self.bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True) y = y.view_as(y_org)