diff --git a/vllm_ascend/lora/punica_wrapper/punica_npu.py b/vllm_ascend/lora/punica_wrapper/punica_npu.py index 9ca747b..8f1eaf9 100644 --- a/vllm_ascend/lora/punica_wrapper/punica_npu.py +++ b/vllm_ascend/lora/punica_wrapper/punica_npu.py @@ -322,7 +322,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): **kwargs) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -338,18 +338,27 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - r = lora_b_stacked.size(-1) + + if lora_a_stacked.dim() == 2: + lora_a_stacked = lora_a_stacked.unsqueeze(0) + if lora_b_stacked.dim() == 2: + lora_b_stacked = lora_b_stacked.unsqueeze(0) + + r = lora_a_stacked.size(-1) + if buffer is None: - # We set the buffer to be float32 by default, consistent with the - # triton op buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - # LogitsProcessorWithLoRA always using bgmv. - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) + + indices = self.sampler_indices + if indices.max() >= lora_a_stacked.size(0): + indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1) + + lora_a_reshaped = lora_a_stacked.transpose(1, 2) + lora_b_reshaped = lora_b_stacked.transpose(1, 2) + + bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale) + bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True) + y = y.view_as(y_org)