[BugFix] Async scheduling and PP compatibility with DP (#2796)

### What this PR does / why we need it?
based on the https://github.com/vllm-project/vllm/pull/23770,
fix Async scheduling and PP compatibility with DP, also fixes issue with
finished requests not being processed in async scheduling and PP cases,
and possible worker race conditions.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.10.2
- vLLM main:
544fe76b95

---------

Signed-off-by: jesse <szxfml@gmail.com>
This commit is contained in:
Song Zhixin
2025-09-19 11:29:50 +08:00
committed by GitHub
parent 0a526768f5
commit 833cd1b698
2 changed files with 78 additions and 26 deletions

View File

@@ -828,6 +828,7 @@ class TestNPUWorker(TestBase):
# Mock scheduler_output and return result
mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1
# Create a real ModelRunnerOutput instance or mock
mock_model_output = MagicMock(spec=ModelRunnerOutput)
worker.model_runner.execute_model.return_value = mock_model_output
@@ -842,9 +843,8 @@ class TestNPUWorker(TestBase):
@patch("vllm_ascend.worker.worker_v1.get_pp_group")
@patch("vllm_ascend.worker.worker_v1.get_tp_group")
@patch("vllm_ascend.worker.worker_v1.has_kv_transfer_group")
def test_execute_model_middle_rank(self, mock_has_kv_transfer_group,
mock_get_tp_group, mock_get_pp_group):
def test_execute_model_middle_rank(self, mock_get_tp_group,
mock_get_pp_group):
"""Test execute_model method - middle rank case"""
from vllm.sequence import IntermediateTensors
@@ -875,10 +875,8 @@ class TestNPUWorker(TestBase):
)
worker.model_runner.execute_model.return_value = mock_intermediate_output
# Set has_kv_transfer_group returns False
mock_has_kv_transfer_group.return_value = False
mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1
# Test execute_model
result = worker.execute_model(mock_scheduler_output)
@@ -926,6 +924,7 @@ class TestNPUWorker(TestBase):
# Mock return result
mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1
mock_model_output = MagicMock(spec=ModelRunnerOutput)
worker.model_runner.execute_model.return_value = mock_model_output
@@ -1150,3 +1149,55 @@ class TestNPUWorker(TestBase):
# Verify calls
worker.model_runner.initialize_kv_cache.assert_called_once_with(
mock_kv_cache_config)
@patch("vllm_ascend.worker.worker_v1.get_pp_group")
@patch("vllm_ascend.worker.worker_v1.get_tp_group")
@patch("vllm_ascend.worker.worker_v1.EMPTY_MODEL_RUNNER_OUTPUT")
def test_execute_model_kv_connector_not_finished(self, mock_empty_output,
mock_get_tp_group,
mock_get_pp_group):
"""Test execute_model method - kv_connector_output not finished sending/recving case"""
from vllm.sequence import IntermediateTensors
from vllm_ascend.worker.worker_v1 import NPUWorker
# Create worker mock
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()
worker.model_runner = MagicMock()
worker.vllm_config = MagicMock()
worker.vllm_config.parallel_config = MagicMock()
worker.vllm_config.parallel_config.distributed_executor_backend = "ray"
# Set as middle rank (not first, not last)
mock_pp_group = MagicMock()
mock_pp_group.is_first_rank = False
mock_pp_group.is_last_rank = False
mock_get_pp_group.return_value = mock_pp_group
# Setup tensor reception data
mock_pp_group.recv_tensor_dict.return_value = {"tensor": "data"}
# Create mock kv_connector_output - both finished_sending and finished_recving are False
mock_kv_connector_output = MagicMock()
mock_kv_connector_output.finished_sending = False
mock_kv_connector_output.finished_recving = False
# Mock return IntermediateTensors with kv_connector_output
mock_intermediate_output = MagicMock(spec=IntermediateTensors)
mock_intermediate_output.tensors = {"output_tensor": "data"}
mock_intermediate_output.kv_connector_output = mock_kv_connector_output
worker.model_runner.execute_model.return_value = mock_intermediate_output
mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1
# Test execute_model
result = worker.execute_model(mock_scheduler_output)
# Verify tensor reception and sending
mock_pp_group.recv_tensor_dict.assert_called_once()
mock_pp_group.send_tensor_dict.assert_called_once()
# When both finished_sending and finished_recving are False, should return EMPTY_MODEL_RUNNER_OUTPUT directly
self.assertEqual(result, mock_empty_output)

View File

@@ -28,8 +28,7 @@ from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
@@ -223,34 +222,36 @@ class NPUWorker(WorkerBase):
scheduler_output: "SchedulerOutput",
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
intermediate_tensors = None
if not get_pp_group().is_first_rank:
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
if not has_kv_transfer_group():
return None
assert parallel_config.distributed_executor_backend != (
"external_launcher") and not get_pp_group().is_last_rank
kv_connector_output = output.kv_connector_output
finished_sending = kv_connector_output.finished_sending
finished_recving = kv_connector_output.finished_recving
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
kv_connector_output = output.kv_connector_output
if not kv_connector_output:
return None
new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
new_output.kv_connector_output = kv_connector_output
return new_output
assert isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput))
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
def load_model(self) -> None: