[Performance] Add async exponential while model executing (#4501)

### What this PR does / why we need it?
Add a control to let the exponential distribution operator overlap with
model execution (default is OFF, because this feature may not perform
well on MoE models, e.g. Qwen3-30B). Enabling async exponential
overlapping provides a performance improvement.
Also, overlapping the exponential operator with model execution can hide
the performance drop introduced by the AI-CPU version of the exponential
operator.
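
For illustration, here is a minimal sketch of the overlap pattern, written
against the CUDA stream API as a stand-in for the analogous `torch.npu`
calls; the function and variable names are hypothetical, not taken from
this PR:

```python
import torch

side_stream = torch.cuda.Stream()
ready = torch.cuda.Event()

def launch_async_exponential(batch_size: int, vocab_size: int) -> torch.Tensor:
    # Order the side stream after work already queued on the main stream,
    # then enqueue the Exp(1) noise kernel there so it can run concurrently
    # with the model's forward pass on the main stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        q = torch.empty(batch_size, vocab_size, device="cuda").exponential_()
        ready.record()
    return q

batch, vocab = 8, 32000
q = launch_async_exponential(batch, vocab)
# Stand-in for the model forward pass running on the main stream.
logits = torch.randn(batch, vocab, device="cuda")
ready.synchronize()  # the sampler must not read q before the event fires
tokens = (logits.softmax(-1) / q).argmax(dim=-1)
```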

**UPDATE** (12/12):
The overlap now uses the same stream introduced in PR #4908.
We moved `do_async_exponential` from `model_runner_v1.py` to
`sampler.py`.
Async exponential is now enabled through `additional_config`: add
`"enable_async_exponential": 1` to `additional_config`.
We now **ONLY** support the default exponential and the AI-CPU
exponential; the old `"enable_async_exponential": 2` option has been
removed for consistency.

### Does this PR introduce _any_ user-facing change?
**YES**, this adds a new `additional_config` option:
`"enable_async_exponential": 1`.
When `enable_async_exponential` is set to 1, async exponential is
enabled and overlapped with the model runner.
When `enable_async_exponential` is set to 0 (the default), async
exponential is disabled, but the exponential still runs on the separate
stream introduced in #4908.
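
A minimal usage sketch (the model name and sampling parameters here are
illustrative only, not taken from this PR):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-8B",
    additional_config={"enable_async_exponential": 1},  # 0 (default) = off
)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.8, top_p=0.95))
```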

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: YuhanBai <yuhan.bai0830@gmail.com>

@@ -3,6 +3,7 @@ import torch_npu
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.sampler import Sampler
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                global_stream, npu_stream_switch)
@@ -41,10 +42,35 @@ class AscendSampler(Sampler):
         # TODO: support logprobs_mode in vllm-ascend
         super().__init__(logprobs_mode=logprobs_mode)
         self.topk_topp_sampler = AscendTopKTopPSampler()
+        self.async_exponential_event = torch.npu.Event()
+
+    def set_q_event(self, q, event):
+        self.topk_topp_sampler.set_q_event(q, event)
+
+    def do_async_exponential(self, b_s, head_dim, generators):
+        # Calculate exponential randoms on a different stream,
+        # overlapping with model execution.
+        with torch.npu.stream(global_stream()):
+            global_stream().wait_stream(torch.npu.current_stream())
+            q = torch.empty((b_s, head_dim), device="npu", dtype=torch.float32)
+            # Covers both the AI-CPU exponential and the default exponential.
+            if len(generators) != q.shape[0]:
+                q.exponential_()
+            if generators:
+                for i, generator in generators.items():
+                    q[i].exponential_(generator=generator)
+            self.async_exponential_event.record()
+        self.set_q_event(q, self.async_exponential_event)
+
+
 class AscendTopKTopPSampler(TopKTopPSampler):
+
+    def set_q_event(self, q, event):
+        # Pass in the async exponential results, together with the event
+        # that guards against reading them too early.
+        self.q = q
+        self.async_event = event
+
     def _apply_top_k_top_p(
         self,
         logits: torch.Tensor,
@@ -99,4 +125,8 @@ class AscendTopKTopPSampler(TopKTopPSampler):
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
+        if get_ascend_config().enable_async_exponential == 1:
+            # Synchronize on the event so q is ready before it is read.
+            self.async_event.synchronize()
+            return probs.div_(self.q).argmax(dim=-1).view(-1), logits_to_return
         return random_sample(probs, generators), logits_to_return
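
For background on the sampling identity used in the last hunk (our
explanation, not text from the PR): `probs.div_(q).argmax(dim=-1)` is the
exponential-race form of the Gumbel-max trick. For i.i.d. E_i ~ Exp(1),
E_i / p_i ~ Exp(p_i), and the minimum of independent exponentials with
rates p_i lands on index k with probability p_k; maximizing p_i / E_i is
the same as minimizing E_i / p_i, so the argmax is an exact sample from
`probs`. A quick empirical check:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
n = 200_000
# One Exp(1) draw per category per trial; argmax(probs / q) samples
# each index with probability equal to its entry in probs.
q = torch.empty(n, 3).exponential_()
samples = (probs / q).argmax(dim=-1)
freq = torch.bincount(samples, minlength=3).float() / n
print(freq)  # ~ tensor([0.1, 0.2, 0.7])
```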