[Disagg][Perf] Use NPU event sync instead of blocking tolist to avoid unintentional copy ops blocking across different NPU streams, improving disagg TTIT/TTFT (#2788)
### What this PR does / why we need it?
When we copy the sampled valid token ids from device to host, avoid
using tolist which would trigger a CUDA wise stream sync if the source
is on device. We change it to use non-blocking copy followed by an
explicit CUDA event sync.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Bring up vLLM server
```bash
VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-14B-Instruct --disable-l
og-requests -tp 8 --max-num-seqs 64 --no-enable-prefix-caching --max_num_batched_tokens=8000
```
## Before:

## After

As shown in the figure, the TTFT decreased
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
---------
Signed-off-by: jesse <szxfml@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.utils import AscendSocVersion
|
||||
@@ -105,3 +106,48 @@ def test_select_moe_comm_method_unsupported_soc():
|
||||
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
|
||||
|
||||
NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)
|
||||
|
||||
|
||||
@patch('vllm_ascend.worker.model_runner_v1.torch_npu')
|
||||
@patch('vllm_ascend.worker.model_runner_v1.torch')
|
||||
def test_init_creates_transfer_event_and_pinned_memory(mock_torch,
|
||||
mock_torch_npu):
|
||||
"""Test that initialization creates transfer event and pinned CPU memory."""
|
||||
# This is a simplified test focusing only on the new attributes
|
||||
# We mock the entire __init__ process and only test the specific lines we added
|
||||
|
||||
# Mock torch.empty to return a mock tensor
|
||||
mock_pinned_tensor = MagicMock()
|
||||
mock_torch.empty.return_value = mock_pinned_tensor
|
||||
|
||||
# Mock torch_npu.npu.Event - 需要设置嵌套的 mock 结构
|
||||
mock_event = MagicMock()
|
||||
mock_torch_npu.npu.Event.return_value = mock_event
|
||||
|
||||
# Create a runner instance using __new__ to bypass __init__
|
||||
runner = NPUModelRunner.__new__(NPUModelRunner)
|
||||
|
||||
# Manually set the attributes we need for our test
|
||||
runner.max_model_len = 2048
|
||||
|
||||
# Test the specific lines from the commit
|
||||
runner.transfer_event = mock_torch_npu.npu.Event()
|
||||
runner.sampled_token_ids_pinned_cpu = mock_torch.empty(
|
||||
(runner.max_model_len, 1),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
|
||||
# Verify max_model_len is set
|
||||
assert runner.max_model_len == 2048
|
||||
|
||||
# Verify transfer_event is created
|
||||
assert runner.transfer_event == mock_event
|
||||
mock_torch_npu.npu.Event.assert_called_once()
|
||||
|
||||
# Verify pinned CPU memory is created with correct parameters
|
||||
assert runner.sampled_token_ids_pinned_cpu == mock_pinned_tensor
|
||||
mock_torch.empty.assert_called_with((2048, 1),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
|
||||
@@ -248,6 +248,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||
self.block_size)
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||
'decode_max_num_seqs', 0)
|
||||
@@ -427,6 +428,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Cached outputs.
|
||||
self._draft_token_ids: Optional[Union[list[list[int]],
|
||||
torch.Tensor]] = None
|
||||
self.transfer_event = torch_npu.npu.Event()
|
||||
self.sampled_token_ids_pinned_cpu = torch.empty(
|
||||
(self.max_model_len, 1),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
|
||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
||||
self.in_profile_run = False
|
||||
@@ -2081,7 +2088,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
# No spec decode tokens.
|
||||
valid_sampled_token_ids = sampled_token_ids.tolist()
|
||||
valid_sampled_token_ids = self._to_list(sampled_token_ids)
|
||||
else:
|
||||
# Includes spec decode tokens.
|
||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||
@@ -3531,3 +3538,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
return False
|
||||
|
||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||
# This is a short term mitigation for issue mentioned in
|
||||
# https://github.com/vllm-project/vllm/issues/22754.
|
||||
# `tolist` would trigger a npu wise stream sync, which
|
||||
# would block other copy ops from other npu streams.
|
||||
# A npu event sync would avoid such a situation. Since
|
||||
# this is in the critical path of every single model
|
||||
# forward loop, this has caused perf issue for a disagg
|
||||
# setup.
|
||||
pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]]
|
||||
pinned.copy_(sampled_token_ids, non_blocking=True)
|
||||
self.transfer_event.record()
|
||||
self.transfer_event.synchronize()
|
||||
return pinned.tolist()
|
||||
Reference in New Issue
Block a user