diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py
index b86ee33..db4adc4 100644
--- a/vllm_ascend/lora/punica_npu.py
+++ b/vllm_ascend/lora/punica_npu.py
@@ -341,13 +341,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-
-        if lora_a_stacked.dim() == 2:
-            lora_a_stacked = lora_a_stacked.unsqueeze(0)
-        if lora_b_stacked.dim() == 2:
-            lora_b_stacked = lora_b_stacked.unsqueeze(0)
-
-        r = lora_a_stacked.size(-1)
+        r = lora_b_stacked.size(-1)
         if buffer is None:
             buffer = torch.zeros((x.size(0), r),
@@ -355,13 +349,8 @@ class PunicaWrapperNPU(PunicaWrapperBase):
                                  device=x.device)
 
         indices = self.sampler_indices
-        if indices.max() >= lora_a_stacked.size(0):
-            indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
-        lora_a_reshaped = lora_a_stacked.transpose(1, 2)
-        lora_b_reshaped = lora_b_stacked.transpose(1, 2)
-
-        bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
-        bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
+        bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
+        bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
         y = y.view_as(y_org)
diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py
index 6e317e9..fe7ee51 100644
--- a/vllm_ascend/ops/vocab_parallel_embedding.py
+++ b/vllm_ascend/ops/vocab_parallel_embedding.py
@@ -262,6 +262,7 @@ class AscendLogitsProcessor(LogitsProcessor):
         sampling_metadata=None,  # type: ignore
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
-        return super().forward(lm_head,
-                               hidden_states,
-                               embedding_bias=embedding_bias)
+        return LogitsProcessor.forward(self,
+                                       lm_head,
+                                       hidden_states,
+                                       embedding_bias=embedding_bias)
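
Note on the punica_npu hunk (a minimal reference sketch, not the NPU kernels): reading the rank from lora_b_stacked.size(-1) is correct if lora_a_stacked is laid out as (num_loras, rank, hidden_size) and lora_b_stacked as (num_loras, vocab_size, rank); under that assumption lora_a_stacked.size(-1) would be hidden_size, so the removed code sized the intermediate buffer to the wrong width. The helpers below (ref_shrink/ref_expand are hypothetical names) sketch the per-token grouped matvec that bgmv_shrink and bgmv_expand are assumed to compute with those shapes.

import torch

def ref_shrink(x, lora_a_stacked, buffer, indices, scale):
    # buffer[i] = scale * lora_a_stacked[indices[i]] @ x[i]
    # x: (tokens, hidden), lora_a_stacked: (num_loras, rank, hidden),
    # buffer: (tokens, rank), indices: (tokens,) LoRA slot per token
    buffer.copy_(scale * torch.einsum("trh,th->tr", lora_a_stacked[indices], x))

def ref_expand(buffer, lora_b_stacked, y, indices, add_inputs=True):
    # y[i] (+)= lora_b_stacked[indices[i]] @ buffer[i]
    # lora_b_stacked: (num_loras, vocab, rank), y: (tokens, vocab)
    delta = torch.einsum("tvr,tr->tv", lora_b_stacked[indices],
                         buffer.to(lora_b_stacked.dtype))
    if add_inputs:
        y.add_(delta.to(y.dtype))
    else:
        y.copy_(delta.to(y.dtype))

On the vocab_parallel_embedding hunk: calling LogitsProcessor.forward(self, ...) explicitly always dispatches to that class's implementation, whereas super().forward(...) resolves through the instance's MRO at runtime.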