From 1dbb8882759e4326f5706f6e610674423376c2f3 Mon Sep 17 00:00:00 2001 From: hongfugui Date: Wed, 30 Jul 2025 09:50:36 +0800 Subject: [PATCH] [Bugfix] LoRA logits einsum dimension mismatch in add_lora_logits (#1583) ### What this PR does / why we need it? This PR fixes a tensor shape mismatch in `add_lora_logits`. Previously, `lora_a_stacked` was passed as shape `[num_loras, in_dim, rank]`, which does not match the expected einsum pattern `"bi, boi -> bo"` used in `bgmv_shrink`. This causes runtime errors like: RuntimeError: einsum(): subscript i has size 3 for operand 1 which does not broadcast with previously seen size 4 ![image](https://github.com/user-attachments/assets/63029479-49ae-4c3c-b995-f6805d15ad06) This fix transposes `lora_a_stacked` and `lora_b_stacked` to match the expected shapes: - `lora_a`: `[num_loras, rank, in_dim]` - `lora_b`: `[num_loras, out_dim, rank]` All unit tests pass after this fix. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? 
``` import torch import pytest from unittest.mock import patch, PropertyMock, ANY from vllm_ascend.lora.punica_wrapper.punica_npu import PunicaWrapperNPU @pytest.fixture def wrapper_cpu(): cfg = {"max_num_batched_tokens": 10, "max_batches": 2, "device": "cpu"} w = PunicaWrapperNPU(**cfg) w.is_prefill = True w.no_lora = False return w def test_add_lora_logits(wrapper_cpu): batch_size = 2 hidden_size = 4 lora_rank = 3 vocab_size = 5 y = torch.zeros(batch_size, vocab_size) x = torch.randn(batch_size, hidden_size) num_loras = 1 lora_a = torch.randn(num_loras, hidden_size, lora_rank) lora_b = torch.randn(num_loras, lora_rank, vocab_size) with patch.object(wrapper_cpu.__class__, "sampler_indices", new_callable=PropertyMock) as mock_idx: mock_idx.return_value = torch.zeros(batch_size, dtype=torch.long) wrapper_cpu.add_lora_logits(y, x, lora_a, lora_b, scale=1.0) assert y.shape == (batch_size, vocab_size) assert not torch.allclose(y, torch.zeros_like(y)) ``` Signed-off-by: hongfugui --- vllm_ascend/lora/punica_wrapper/punica_npu.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/lora/punica_wrapper/punica_npu.py b/vllm_ascend/lora/punica_wrapper/punica_npu.py index 9ca747b..8f1eaf9 100644 --- a/vllm_ascend/lora/punica_wrapper/punica_npu.py +++ b/vllm_ascend/lora/punica_wrapper/punica_npu.py @@ -322,7 +322,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): **kwargs) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. 
- + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -338,18 +338,27 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - r = lora_b_stacked.size(-1) + + if lora_a_stacked.dim() == 2: + lora_a_stacked = lora_a_stacked.unsqueeze(0) + if lora_b_stacked.dim() == 2: + lora_b_stacked = lora_b_stacked.unsqueeze(0) + + r = lora_a_stacked.size(-1) + if buffer is None: - # We set the buffer to be float32 by default, consistent with the - # triton op buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - # LogitsProcessorWithLoRA always using bgmv. - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) + + indices = self.sampler_indices + if indices.max() >= lora_a_stacked.size(0): + indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1) + + lora_a_reshaped = lora_a_stacked.transpose(1, 2) + lora_b_reshaped = lora_b_stacked.transpose(1, 2) + + bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale) + bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True) + y = y.view_as(y_org)