xc-llm-ascend/vllm_ascend/lora
hongfugui 1dbb888275 [Bugfix] LoRA logits einsum dimension mismatch in add_lora_logits (#1583)
### What this PR does / why we need it?
This PR fixes a tensor shape mismatch in `add_lora_logits`.

Previously, `lora_a_stacked` was passed with shape `[num_loras, in_dim, rank]`,
which does not match the einsum pattern `"bi, boi -> bo"` used in
`bgmv_shrink`: that pattern expects the input dimension `i` to be the last
axis of the weight tensor.

This causes runtime errors like:
RuntimeError: einsum(): subscript i has size 3 for operand 1 which does not broadcast with previously seen size 4
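
The mismatch can be reproduced outside the wrapper. The snippet below is a minimal sketch that uses NumPy as a stand-in for the torch tensors (an assumption made to keep the example dependency-light); the sizes mirror the error message, with `in_dim = 4` and `rank = 3`. `np.einsum` reports the same problem as a `ValueError` rather than torch's `RuntimeError`.

```python
import numpy as np

batch, in_dim, rank = 2, 4, 3

x = np.ones((batch, in_dim))
# Old layout after selecting one LoRA per batch row: [batch, in_dim, rank].
selected_a = np.ones((batch, in_dim, rank))

try:
    # "i" is 4 in x but 3 in the last axis of selected_a, so this fails.
    np.einsum("bi,boi->bo", x, selected_a)
    error = None
except ValueError as exc:
    error = exc

print(type(error).__name__)
```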

![image](https://github.com/user-attachments/assets/63029479-49ae-4c3c-b995-f6805d15ad06)

This fix transposes `lora_a_stacked` and `lora_b_stacked` to match the
expected shapes:
- `lora_a`: `[num_loras, rank, in_dim]`
- `lora_b`: `[num_loras, out_dim, rank]`
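
As an illustration of why the transposed layout works, here is a minimal sketch, again using NumPy in place of torch. The helper `bgmv` is hypothetical and is not the actual `bgmv_shrink` / `bgmv_expand` implementation; it only applies the `"bi, boi -> bo"` pattern quoted above, and it assumes the expand step uses the same pattern as the shrink step.

```python
import numpy as np

batch, in_dim, rank, out_dim, num_loras = 2, 4, 3, 5, 1
rng = np.random.default_rng(0)

x = rng.standard_normal((batch, in_dim))
indices = np.zeros(batch, dtype=np.int64)  # every row uses LoRA 0

# Layouts as passed in by the caller (same as in the test below).
lora_a_stacked = rng.standard_normal((num_loras, in_dim, rank))
lora_b_stacked = rng.standard_normal((num_loras, rank, out_dim))

# Swap the last two axes so the contracted dimension comes last.
lora_a = lora_a_stacked.transpose(0, 2, 1)  # [num_loras, rank, in_dim]
lora_b = lora_b_stacked.transpose(0, 2, 1)  # [num_loras, out_dim, rank]

def bgmv(inputs, weights, indices):
    """Hypothetical helper: per-row weight lookup, then contraction over i."""
    return np.einsum("bi,boi->bo", inputs, weights[indices])

hidden = bgmv(x, lora_a, indices)      # shrink: [batch, rank]
delta = bgmv(hidden, lora_b, indices)  # expand: [batch, out_dim]

# With a single LoRA this equals the plain matrix product x @ A @ B.
reference = x @ lora_a_stacked[0] @ lora_b_stacked[0]
print(delta.shape, np.allclose(delta, reference))
```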

All unit tests pass after this fix.
### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
```
import torch
import pytest
from unittest.mock import patch, PropertyMock
from vllm_ascend.lora.punica_wrapper.punica_npu import PunicaWrapperNPU

@pytest.fixture
def wrapper_cpu():
    cfg = {"max_num_batched_tokens": 10, "max_batches": 2, "device": "cpu"}
    w = PunicaWrapperNPU(**cfg)
    w.is_prefill = True
    w.no_lora = False
    return w

def test_add_lora_logits(wrapper_cpu):
    batch_size = 2
    hidden_size = 4
    lora_rank = 3
    vocab_size = 5
    
    y = torch.zeros(batch_size, vocab_size)
    x = torch.randn(batch_size, hidden_size)
    
    num_loras = 1
    lora_a = torch.randn(num_loras, hidden_size, lora_rank)
    lora_b = torch.randn(num_loras, lora_rank, vocab_size)
    
    with patch.object(wrapper_cpu.__class__, "sampler_indices", 
                     new_callable=PropertyMock) as mock_idx:

        mock_idx.return_value = torch.zeros(batch_size, dtype=torch.long)

        wrapper_cpu.add_lora_logits(y, x, lora_a, lora_b, scale=1.0)

        assert y.shape == (batch_size, vocab_size)
        assert not torch.allclose(y, torch.zeros_like(y))
```

Signed-off-by: hongfugui <hongfugui_yewu@cmss.chinamobile.com>
2025-07-30 09:50:36 +08:00
..
2025-04-22 08:57:25 +08:00