From 1b5513aa914aa37cb539d410faf6e021e6dc5481 Mon Sep 17 00:00:00 2001
From: FuNanyang <43992549+coder-fny@users.noreply.github.com>
Date: Tue, 2 Dec 2025 20:35:51 +0800
Subject: [PATCH] [performance] Enhance performance after enabling min_p
 (#4529)

### What this PR does / why we need it?
When the min_p post-processing parameter is enabled, the original vLLM
implementation introduces the aclnnIndexPutImpl operator, which performs
poorly on NPU.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Profiling was collected after enabling min_p; the performance has been
greatly improved.

- vLLM version: v0.11.2

---------

Signed-off-by: funanyang <985619145@qq.com>
---
 vllm_ascend/sample/logits_processor/builtin.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/vllm_ascend/sample/logits_processor/builtin.py b/vllm_ascend/sample/logits_processor/builtin.py
index f38d9402..9910df1b 100644
--- a/vllm_ascend/sample/logits_processor/builtin.py
+++ b/vllm_ascend/sample/logits_processor/builtin.py
@@ -33,3 +33,20 @@ class AscendMinPLogitsProcessor(MinPLogitsProcessor):
         self.min_p_device = self.min_p_cpu_tensor
         # Current slice of the device tensor
         self.min_p: torch.Tensor = self.min_p_device[:0]
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if not self.min_p_count:
+            return logits
+        # Convert logits to probability distribution
+        probability_values = torch.nn.functional.softmax(logits, dim=-1)
+        # Calculate maximum probabilities per sequence
+        max_probabilities = torch.amax(probability_values,
+                                       dim=-1,
+                                       keepdim=True)
+        # Adjust min_p
+        adjusted_min_p = max_probabilities.mul_(self.min_p)
+        # Identify valid tokens using threshold comparison
+        invalid_token_mask = probability_values < adjusted_min_p
+        # Apply mask using boolean indexing
+        logits.masked_fill_(invalid_token_mask, -float('inf'))
+        return logits