[Bugfix][LoRA] Fix forward error and shape mismatch when using LoRA (#3153)

### What this PR does / why we need it?
Relying on #3044, this PR aims to further fix:
1. The forward error occured when `LogitsProcessorWithLoRA` calls
`AscendLogitsProcessor.forward`. Since `LogitsProcessorWithLoRA`
bypasses the MRO to call it, `super().forward(...)` in
`AscendLogitsProcessor.forward` will raise an error. This PR fixes it by
directly invoking `LogitsProcessor.forward(self, ...)`;
2. The shape mismatch in `add_lora_logits` in punica_npu.py. The
`lora_a_stacked` and `lora_b_stacked` are organized as [num_loras, 1,
lora_rank, hidden_size] and [num_loras, 1, vocab_size, lora_rank] shapes
respectively, but they are misunderstood in #1583---the last two
dimensions were assumed in reverse order, which causes errors in
`bgmv_shrink` and `bgmv_expand`. This PR fixes it by reverting it to the
previous version to align with the implementation in punica_cpu.py in
vllm.

### Dependencies
This PR depends on changes introduced by #3044 (LoRA support for
`AscendQKVParallelLinear` and `AscendMergedQKVParallelLinear` layers).

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
The LoRA-related tests, e.g., test_ilama_lora.py and
test_ilama_lora_tp2.py, use ilama-3.2-1B, and this model is regarded as
`TransformersForCausalLM`, where `embedding_modules` attribute lacks
`lm_head`. However, `LlamaForCausalLM` and most other models include
both `embed_tokens` and `lm_head` in `embedding_modules`. This attribute
contributes to `supported_lora_modules` when using LoRA in vllm.
Therefore, without `lm_head` in `embedding_modules`, current tests using
ilama-3.2-1B are unable to find the abve errors since
`LogitsProcessorWithLoRA` replacing `lm_head` is skipped. Simply using
Meta-Llama-3.1-8B-Instruct can reproduce the above errors and check
whether these fixes can work. What's more, it's necessary to add more
comprehensive tests for LoRA.

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
Zetong Li
2025-09-28 17:30:50 +08:00
committed by GitHub
parent 3d21ed9ee8
commit a86ece5e39
2 changed files with 7 additions and 17 deletions

View File

@@ -341,13 +341,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
if lora_a_stacked.dim() == 2:
lora_a_stacked = lora_a_stacked.unsqueeze(0)
if lora_b_stacked.dim() == 2:
lora_b_stacked = lora_b_stacked.unsqueeze(0)
r = lora_a_stacked.size(-1)
r = lora_b_stacked.size(-1)
if buffer is None:
buffer = torch.zeros((x.size(0), r),
@@ -355,13 +349,8 @@ class PunicaWrapperNPU(PunicaWrapperBase):
device=x.device)
indices = self.sampler_indices
if indices.max() >= lora_a_stacked.size(0):
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
y = y.view_as(y_org)

View File

@@ -262,6 +262,7 @@ class AscendLogitsProcessor(LogitsProcessor):
sampling_metadata=None, # type: ignore
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
return super().forward(lm_head,
hidden_states,
embedding_bias=embedding_bias)
return LogitsProcessor.forward(self,
lm_head,
hidden_states,
embedding_bias=embedding_bias)