### What this PR does / why we need it?
When the min_p post-processing parameter is enabled, the original vLLM implementation introduces the aclnInIndexPutImpl operator, which performs poorly on NPU.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Profiling was collected with min_p enabled; performance improved significantly.

- vLLM version: v0.11.2

---------

Signed-off-by: funanyang <985619145@qq.com>
53 lines
2.2 KiB
Python
53 lines
2.2 KiB
Python
import torch
|
|
from vllm.config import VllmConfig
|
|
from vllm.v1.sample.logits_processor import MinPLogitsProcessor
|
|
|
|
|
|
class AscendMinPLogitsProcessor(MinPLogitsProcessor):
    """Min-p logits processor whose ``apply`` masks with ``masked_fill_``.

    ``__init__`` optionally re-allocates the min-p staging buffers so they can
    hold ``decode_max_num_seqs`` requests, and ``apply`` filters logits in
    place using ``Tensor.masked_fill_`` rather than boolean-index assignment.
    """

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool) -> None:
        """Initialise the processor and, if needed, enlarge its buffers.

        Args:
            vllm_config: Engine configuration; ``scheduler_config.max_num_seqs``
                and the optional ``scheduler_config.decode_max_num_seqs`` are
                read from it.
            device: Device on which the min-p values are kept for ``apply``.
            is_pin_memory: Whether the CPU staging tensor is allocated in
                pinned memory.
        """
        super().__init__(vllm_config, device, is_pin_memory)

        scheduler_config = vllm_config.scheduler_config
        decode_max_num_seqs = getattr(scheduler_config,
                                      'decode_max_num_seqs', 0)
        # Nothing to re-size when the option is absent, zero or unset.
        # Returning early keeps `max_num_reqs` from ever being referenced
        # unbound, and keeps a None value away from the max() call below.
        # NOTE(review): presumably the parent constructor has already
        # allocated buffers sized by max_num_seqs - verify against
        # MinPLogitsProcessor.__init__.
        if not decode_max_num_seqs:
            return

        max_num_reqs = max(scheduler_config.max_num_seqs,
                           decode_max_num_seqs)

        self.min_p_count: int = 0

        # CPU-side staging tensor, plus a NumPy view that shares its memory.
        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=is_pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        # A separate device tensor is only needed off-CPU; on CPU the staging
        # tensor itself is used.
        self.use_double_tensor = torch.device(device).type != "cpu"
        if self.use_double_tensor:
            # Pre-allocated device tensor
            self.min_p_device: torch.Tensor = torch.empty(
                (max_num_reqs, ), dtype=torch.float32, device=device)
        else:
            self.min_p_device = self.min_p_cpu_tensor

        # Current slice of the device tensor (empty until state is updated).
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Mask tokens whose probability is below ``min_p * max_prob``.

        Args:
            logits: Logits tensor; the vocabulary is taken to be the last
                dimension. It is modified in place.

        Returns:
            The same ``logits`` tensor, with filtered positions set to
            ``-inf``. Returned untouched when ``self.min_p_count`` is zero.
        """
        if not self.min_p_count:
            return logits

        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Scale the per-row maximum by min_p in place (the amax result is a
        # fresh tensor, so nothing else observes the mutation).
        # NOTE(review): assumes self.min_p broadcasts against the keepdim
        # result, e.g. shape (num_reqs, 1) - confirm against the parent's
        # state update.
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify tokens that fall below the threshold
        invalid_token_mask = probability_values < adjusted_min_p
        # Fill masked positions in place with masked_fill_ instead of
        # boolean-index assignment.
        logits.masked_fill_(invalid_token_mask, -float('inf'))
        return logits
|