[Bugfix] LoRA logits einsum dimension mismatch in add_lora_logits (#1583)

### What this PR does / why we need it?
This PR fixes a tensor shape mismatch in `add_lora_logits`.

Previously, `lora_a_stacked` was passed as shape `[num_loras, in_dim,
rank]`, which does not match the expected einsum pattern `"bi, boi ->
bo"` used in `bgmv_shrink`.

This caused runtime errors such as:

```
RuntimeError: einsum(): subscript i has size 3 for operand 1 which does not broadcast with previously seen size 4
```

![image](https://github.com/user-attachments/assets/63029479-49ae-4c3c-b995-f6805d15ad06)

This fix transposes `lora_a_stacked` and `lora_b_stacked` to match the
expected shapes:
- `lora_a`: `[num_loras, rank, in_dim]`
- `lora_b`: `[num_loras, out_dim, rank]`

All unit tests pass after this fix.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
```
import torch
import pytest
from unittest.mock import patch, PropertyMock, ANY
from vllm_ascend.lora.punica_wrapper.punica_npu import PunicaWrapperNPU

@pytest.fixture
def wrapper_cpu():
    """Return a CPU-backed ``PunicaWrapperNPU`` flagged as prefill with LoRA on.

    The wrapper is built with ``max_num_batched_tokens=10``,
    ``max_batches=2`` and ``device="cpu"``; ``is_prefill`` and ``no_lora``
    are then overridden directly on the instance.
    """
    wrapper = PunicaWrapperNPU(
        max_num_batched_tokens=10,
        max_batches=2,
        device="cpu",
    )
    wrapper.is_prefill = True
    wrapper.no_lora = False
    return wrapper

def test_add_lora_logits(wrapper_cpu):
    """Smoke-test ``add_lora_logits`` with one stacked LoRA adapter.

    ``lora_a`` is laid out as ``[num_loras, hidden, rank]`` and ``lora_b`` as
    ``[num_loras, rank, vocab]``. The test only checks that the call
    succeeds, that ``y`` keeps its ``(batch, vocab)`` shape, and that ``y``
    is no longer all zeros afterwards; it does not verify numerical values.
    """
    n_tokens, n_hidden, n_rank, n_vocab, n_adapters = 2, 4, 3, 5, 1

    y = torch.zeros(n_tokens, n_vocab)
    x = torch.randn(n_tokens, n_hidden)
    lora_a = torch.randn(n_adapters, n_hidden, n_rank)
    lora_b = torch.randn(n_adapters, n_rank, n_vocab)

    # ``sampler_indices`` is replaced at class level by a PropertyMock so the
    # wrapper sees an all-zero index tensor (presumably selecting the single
    # adapter for every row -- confirm against the wrapper implementation).
    indices_patch = patch.object(
        type(wrapper_cpu),
        "sampler_indices",
        new_callable=PropertyMock,
    )
    with indices_patch as mock_idx:
        mock_idx.return_value = torch.zeros(n_tokens, dtype=torch.long)

        wrapper_cpu.add_lora_logits(y, x, lora_a, lora_b, scale=1.0)

        assert y.shape == (n_tokens, n_vocab)
        assert not torch.allclose(y, torch.zeros_like(y))

```

Signed-off-by: hongfugui <hongfugui_yewu@cmss.chinamobile.com>
This commit is contained in:
hongfugui
2025-07-30 09:50:36 +08:00
committed by GitHub
parent d80b0cca5d
commit 1dbb888275

View File

@@ -322,7 +322,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
**kwargs) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
@@ -338,18 +338,27 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if lora_a_stacked.dim() == 2:
lora_a_stacked = lora_a_stacked.unsqueeze(0)
if lora_b_stacked.dim() == 2:
lora_b_stacked = lora_b_stacked.unsqueeze(0)
r = lora_a_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
indices = self.sampler_indices
if indices.max() >= lora_a_stacked.size(0):
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
y = y.view_as(y_org)