[Disagg][Perf] Use NPU event sync instead of blocking tolist to avoid unintentional copy ops blocking across different NPU streams, improving disagg TTIT/TTFT (#2788)

### What this PR does / why we need it?
When we copy the sampled valid token ids from device to host, avoid
using tolist which would trigger a CUDA wise stream sync if the source
is on device. We change it to use non-blocking copy followed by an
explicit CUDA event sync.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Bring up vLLM server
```bash
VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-14B-Instruct --disable-l
og-requests -tp 8 --max-num-seqs 64 --no-enable-prefix-caching --max_num_batched_tokens=8000
```
## Before:

![76218085a0cde9b2a73214e35fb7fc08](https://github.com/user-attachments/assets/38cbd02d-d380-47f8-a111-4bd859102eb1)
## After

![6c2111136673332244d3ce11060f4048](https://github.com/user-attachments/assets/957f9bf1-ec50-4f49-9318-f4876b3e3691)

As shown in the figure, the TTFT decreased


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

---------

Signed-off-by: jesse <szxfml@gmail.com>
This commit is contained in:
Song Zhixin
2025-09-24 11:21:58 +08:00
committed by GitHub
parent c4b976af1a
commit 6995a7bc5b
2 changed files with 69 additions and 1 deletions

View File

@@ -248,6 +248,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.block_size = vllm_config.cache_config.block_size
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
self.block_size)
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
@@ -427,6 +428,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
self.transfer_event = torch_npu.npu.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_model_len, 1),
dtype=torch.int64,
device="cpu",
pin_memory=True)
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
@@ -2081,7 +2088,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
valid_sampled_token_ids = self._to_list(sampled_token_ids)
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -3531,3 +3538,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _build_drafter_prepare_inputs_torchair_param(self):
return False
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
# This is a short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/22754.
# `tolist` would trigger a npu wise stream sync, which
# would block other copy ops from other npu streams.
# A npu event sync would avoid such a situation. Since
# this is in the critical path of every single model
# forward loop, this has caused perf issue for a disagg
# setup.
pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]]
pinned.copy_(sampled_token_ids, non_blocking=True)
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()