[Bugfix] LoRA logits einsum dimension mismatch in add_lora_logits (#1583)
### What this PR does / why we need it? This PR fixes a tensor shape mismatch in `add_lora_logits`. Previously, `lora_a_stacked` was passed as shape `[num_loras, in_dim, rank]`, which does not match the expected einsum pattern `"bi, boi -> bo"` used in `bgmv_shrink`. This causes runtime errors like: RuntimeError: einsum(): subscript i has size 3 for operand 1 which does not broadcast with previously seen size 4  This fix transposes `lora_a_stacked` and `lora_b_stacked` to match the expected shapes: - `lora_a`: `[num_loras, rank, in_dim]` - `lora_b`: `[num_loras, out_dim, rank]` All unit tests pass after this fix. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? ``` import torch import pytest from unittest.mock import patch, PropertyMock, ANY from vllm_ascend.lora.punica_wrapper.punica_npu import PunicaWrapperNPU @pytest.fixture def wrapper_cpu(): cfg = {"max_num_batched_tokens": 10, "max_batches": 2, "device": "cpu"} w = PunicaWrapperNPU(**cfg) w.is_prefill = True w.no_lora = False return w def test_add_lora_logits(wrapper_cpu): batch_size = 2 hidden_size = 4 lora_rank = 3 vocab_size = 5 y = torch.zeros(batch_size, vocab_size) x = torch.randn(batch_size, hidden_size) num_loras = 1 lora_a = torch.randn(num_loras, hidden_size, lora_rank) lora_b = torch.randn(num_loras, lora_rank, vocab_size) with patch.object(wrapper_cpu.__class__, "sampler_indices", new_callable=PropertyMock) as mock_idx: mock_idx.return_value = torch.zeros(batch_size, dtype=torch.long) wrapper_cpu.add_lora_logits(y, x, lora_a, lora_b, scale=1.0) assert y.shape == (batch_size, vocab_size) assert not torch.allclose(y, torch.zeros_like(y)) Signed-off-by: hongfugui <hongfugui_yewu@cmss.chinamobile.com>
This commit is contained in:
@@ -322,7 +322,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
@@ -338,18 +338,27 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
r = lora_b_stacked.size(-1)
|
||||
|
||||
if lora_a_stacked.dim() == 2:
|
||||
lora_a_stacked = lora_a_stacked.unsqueeze(0)
|
||||
if lora_b_stacked.dim() == 2:
|
||||
lora_b_stacked = lora_b_stacked.unsqueeze(0)
|
||||
|
||||
r = lora_a_stacked.size(-1)
|
||||
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default, consistent with the
|
||||
# triton op
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
# LogitsProcessorWithLoRA always using bgmv.
|
||||
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
|
||||
bgmv_expand(buffer,
|
||||
lora_b_stacked,
|
||||
y,
|
||||
self.sampler_indices,
|
||||
add_inputs=True)
|
||||
|
||||
indices = self.sampler_indices
|
||||
if indices.max() >= lora_a_stacked.size(0):
|
||||
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
|
||||
|
||||
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
|
||||
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
|
||||
|
||||
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
||||
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
||||
|
||||
y = y.view_as(y_org)
|
||||
|
||||
Reference in New Issue
Block a user