From a86ece5e399db9aa9d7186ab7e51bc0e0dad4134 Mon Sep 17 00:00:00 2001 From: Zetong Li <48438720+slippersss@users.noreply.github.com> Date: Sun, 28 Sep 2025 17:30:50 +0800 Subject: [PATCH] [Bugfix][LoRA] Fix forward error and shape mismatch when using LoRA (#3153) ### What this PR does / why we need it? Relying on #3044, this PR aims to further fix: 1. The forward error occurred when `LogitsProcessorWithLoRA` calls `AscendLogitsProcessor.forward`. Since `LogitsProcessorWithLoRA` bypasses the MRO to call it, `super().forward(...)` in `AscendLogitsProcessor.forward` will raise an error. This PR fixes it by directly invoking `LogitsProcessor.forward(self, ...)`; 2. The shape mismatch in `add_lora_logits` in punica_npu.py. The `lora_a_stacked` and `lora_b_stacked` are organized as [num_loras, 1, lora_rank, hidden_size] and [num_loras, 1, vocab_size, lora_rank] shapes respectively, but they are misunderstood in #1583---the last two dimensions were assumed in reverse order, which causes errors in `bgmv_shrink` and `bgmv_expand`. This PR fixes it by reverting it to the previous version to align with the implementation in punica_cpu.py in vllm. ### Dependencies This PR depends on changes introduced by #3044 (LoRA support for `AscendQKVParallelLinear` and `AscendMergedQKVParallelLinear` layers). ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? The LoRA-related tests, e.g., test_ilama_lora.py and test_ilama_lora_tp2.py, use ilama-3.2-1B, and this model is regarded as `TransformersForCausalLM`, where `embedding_modules` attribute lacks `lm_head`. However, `LlamaForCausalLM` and most other models include both `embed_tokens` and `lm_head` in `embedding_modules`. This attribute contributes to `supported_lora_modules` when using LoRA in vllm. Therefore, without `lm_head` in `embedding_modules`, current tests using ilama-3.2-1B are unable to find the above errors since `LogitsProcessorWithLoRA` replacing `lm_head` is skipped. 
Simply using Meta-Llama-3.1-8B-Instruct can reproduce the above errors and check whether these fixes can work. What's more, it's necessary to add more comprehensive tests for LoRA. - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/f225ea7dd98e9f29752e5c032cd4a8ee1d712f16 Signed-off-by: Zetong Li --- vllm_ascend/lora/punica_npu.py | 17 +++-------------- vllm_ascend/ops/vocab_parallel_embedding.py | 7 ++++--- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py index b86ee33..db4adc4 100644 --- a/vllm_ascend/lora/punica_npu.py +++ b/vllm_ascend/lora/punica_npu.py @@ -341,13 +341,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - - if lora_a_stacked.dim() == 2: - lora_a_stacked = lora_a_stacked.unsqueeze(0) - if lora_b_stacked.dim() == 2: - lora_b_stacked = lora_b_stacked.unsqueeze(0) - - r = lora_a_stacked.size(-1) + r = lora_b_stacked.size(-1) if buffer is None: buffer = torch.zeros((x.size(0), r), @@ -355,13 +349,8 @@ class PunicaWrapperNPU(PunicaWrapperBase): device=x.device) indices = self.sampler_indices - if indices.max() >= lora_a_stacked.size(0): - indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1) - lora_a_reshaped = lora_a_stacked.transpose(1, 2) - lora_b_reshaped = lora_b_stacked.transpose(1, 2) - - bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale) - bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True) + bgmv_shrink(x, lora_a_stacked, buffer, indices, scale) + bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True) y = y.view_as(y_org) diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 6e317e9..fe7ee51 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -262,6 +262,7 @@ class AscendLogitsProcessor(LogitsProcessor): 
sampling_metadata=None, # type: ignore embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: - return super().forward(lm_head, - hidden_states, - embedding_bias=embedding_bias) + return LogitsProcessor.forward(self, + lm_head, + hidden_states, + embedding_bias=embedding_bias)