import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                               global_stream, npu_stream_switch)


DEFAULT_LOGPROBS_MODE = "raw_logprobs"


def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-NPU synchronization.
    """
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    with npu_stream_switch(global_stream()):
        q = torch.empty_like(probs)
        if len(generators) != probs.shape[0]:
            q.exponential_()
        if generators:
            # TODO(woosuk): This can be slow because we handle each request
            # one by one. Optimize this.
            for i, generator in generators.items():
                q[i].exponential_(generator=generator)
    torch.npu.current_stream().wait_stream(global_stream())
    return probs.div_(q).argmax(dim=-1).view(-1)
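
# Why dividing by exponential noise samples correctly (a minimal sketch with
# illustrative values, not part of the original module): if E_i ~ Exp(1)
# independently, then argmax_i(p_i / E_i) selects index i with probability
# proportional to p_i, so divide-and-argmax is an exact, synchronization-free
# replacement for torch.multinomial.
#
#     probs = torch.tensor([[0.1, 0.2, 0.7]])
#     q = torch.empty_like(probs).exponential_()  # E_i ~ Exp(1)
#     token = probs.div(q).argmax(dim=-1)         # ~ Categorical(probs)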


class AscendSampler(Sampler):

    def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
        # TODO: support logprobs_mode in vllm-ascend
        super().__init__(logprobs_mode=logprobs_mode)
        self.topk_topp_sampler = AscendTopKTopPSampler()
        self.async_exponential_event = torch.npu.Event()

    def set_q_event(self, q, event):
        self.topk_topp_sampler.set_q_event(q, event)

    def do_async_exponential(self, b_s, head_dim, generators):
        # Generate the exponential randoms on a separate stream so the work
        # overlaps with model execution.
        with torch.npu.stream(global_stream()):
            global_stream().wait_stream(torch.npu.current_stream())
            q = torch.empty((b_s, head_dim), device="npu", dtype=torch.float32)
            # Fill asynchronously with either the AI-CPU exponential or the
            # default exponential.
            if len(generators) != q.shape[0]:
                q.exponential_()
            if generators:
                for i, generator in generators.items():
                    q[i].exponential_(generator=generator)
            self.async_exponential_event.record()
        self.set_q_event(q, self.async_exponential_event)
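
    # A sketch of the intended overlap (hypothetical call sites, not part of
    # this module): noise generation is launched before the forward pass,
    # runs on global_stream() while the model executes, and the recorded
    # event is synchronized in forward_native() before q is consumed.
    #
    #     sampler.do_async_exponential(batch_size, vocab_size, generators)
    #     logits = model(input_ids)  # overlaps with the exponential fill
    #     tokens, _ = sampler.topk_topp_sampler.forward_native(
    #         logits, generators, k, p)  # waits on the event, then samples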


class AscendTopKTopPSampler(TopKTopPSampler):

    def set_q_event(self, q, event):
        # Store the asynchronously generated exponential results, together
        # with the event that guards against using them before they are ready.
        self.q = q
        self.async_event = event

    def _apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        k: torch.Tensor,
        p: torch.Tensor,
    ) -> torch.Tensor:
        # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but
        # aclnnApplyTopKTopP currently does not support 310P.
        if (get_ascend_device_type() != AscendDeviceType._310P
                and p is not None and k is not None
                and 1 <= int(k.max()) <= 1024):
            # npu_top_k_top_p's parameter order is (logits, p, k),
            # not (logits, k, p).
            return torch_npu.npu_top_k_top_p(logits, p, k)
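
        # A minimal sketch of the fused fast path (shapes and dtypes are
        # assumptions, not taken from this module): k and p are per-row
        # tensors, and the operator masks each row's logits in one kernel.
        #
        #     logits = torch.randn(2, 32000, device="npu")
        #     k = torch.tensor([50, 100], device="npu", dtype=torch.int32)
        #     p = torch.tensor([0.9, 0.8], device="npu")
        #     masked = torch_npu.npu_top_k_top_p(logits, p, k)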

        if p is None and k is None:
            return logits

        probs = logits.softmax(dim=-1)
        probs_sort, _ = probs.sort(dim=-1, descending=False)

        if k is not None:
            top_k_count = probs_sort.size(1) - k.to(
                torch.long)  # shape: (batch, )
            top_k_count = top_k_count.unsqueeze(dim=1)
            top_k_cutoff = probs_sort.gather(-1, top_k_count)

            # Make sure rows that request no top-k are a no-op.
            no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
            top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

            elements_to_discard = probs < top_k_cutoff
            logits.masked_fill_(elements_to_discard, -float("inf"))

        if p is not None:
            cumprob = torch.cumsum(probs_sort, dim=-1)
            top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
            top_p_mask[:, -1] = False  # at least one

            top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
            top_p_cutoff = probs_sort.gather(-1, top_p_count)
            elements_to_discard = probs < top_p_cutoff
            logits.masked_fill_(elements_to_discard, -float("inf"))

        return logits
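
        # A worked miniature of the fallback above (illustrative numbers,
        # not from this module). Take one row with ascending sorted probs
        # [0.1, 0.2, 0.3, 0.4]:
        #   * top-k with k = 2: cutoff index = 4 - 2 = 2, so
        #     top_k_cutoff = 0.3, and every prob < 0.3 is masked to -inf.
        #   * top-p with p = 0.6: cumprob = [0.1, 0.3, 0.6, 1.0], and the
        #     mask cumprob <= 1 - p = 0.4 drops the two smallest tokens,
        #     keeping 0.3 and 0.4 (total mass 0.7 >= p).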

    def forward_native(self, logits, generators, k, p):
        """Override the PyTorch-native implementation with torch_npu."""
        logits = self._apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if get_ascend_config().enable_async_exponential == 1:
            # Wait for the asynchronously generated noise before using it,
            # to avoid a use-before-ready error.
            self.async_event.synchronize()
            return probs.div_(self.q).argmax(dim=-1).view(-1), logits_to_return
        return random_sample(probs, generators), logits_to_return