diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index ccbd3ae..7ae9aa3 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -828,6 +828,7 @@ class TestNPUWorker(TestBase): # Mock scheduler_output and return result mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 # Create a real ModelRunnerOutput instance or mock mock_model_output = MagicMock(spec=ModelRunnerOutput) worker.model_runner.execute_model.return_value = mock_model_output @@ -842,9 +843,8 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.worker.worker_v1.get_pp_group") @patch("vllm_ascend.worker.worker_v1.get_tp_group") - @patch("vllm_ascend.worker.worker_v1.has_kv_transfer_group") - def test_execute_model_middle_rank(self, mock_has_kv_transfer_group, - mock_get_tp_group, mock_get_pp_group): + def test_execute_model_middle_rank(self, mock_get_tp_group, + mock_get_pp_group): """Test execute_model method - middle rank case""" from vllm.sequence import IntermediateTensors @@ -875,10 +875,8 @@ class TestNPUWorker(TestBase): ) worker.model_runner.execute_model.return_value = mock_intermediate_output - # Set has_kv_transfer_group returns False - mock_has_kv_transfer_group.return_value = False - mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 # Test execute_model result = worker.execute_model(mock_scheduler_output) @@ -926,6 +924,7 @@ class TestNPUWorker(TestBase): # Mock return result mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 mock_model_output = MagicMock(spec=ModelRunnerOutput) worker.model_runner.execute_model.return_value = mock_model_output @@ -1150,3 +1149,55 @@ class TestNPUWorker(TestBase): # Verify calls worker.model_runner.initialize_kv_cache.assert_called_once_with( mock_kv_cache_config) + + @patch("vllm_ascend.worker.worker_v1.get_pp_group") + @patch("vllm_ascend.worker.worker_v1.get_tp_group") + @patch("vllm_ascend.worker.worker_v1.EMPTY_MODEL_RUNNER_OUTPUT") + def test_execute_model_kv_connector_not_finished(self, mock_empty_output, + mock_get_tp_group, + mock_get_pp_group): + """Test execute_model method - kv_connector_output not finished sending/recving case""" + from vllm.sequence import IntermediateTensors + + from vllm_ascend.worker.worker_v1 import NPUWorker + + # Create worker mock + with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): + worker = NPUWorker() + worker.model_runner = MagicMock() + worker.vllm_config = MagicMock() + worker.vllm_config.parallel_config = MagicMock() + worker.vllm_config.parallel_config.distributed_executor_backend = "ray" + + # Set as middle rank (not first, not last) + mock_pp_group = MagicMock() + mock_pp_group.is_first_rank = False + mock_pp_group.is_last_rank = False + mock_get_pp_group.return_value = mock_pp_group + + # Setup tensor reception data + mock_pp_group.recv_tensor_dict.return_value = {"tensor": "data"} + + # Create mock kv_connector_output - both finished_sending and finished_recving are False + mock_kv_connector_output = MagicMock() + mock_kv_connector_output.finished_sending = False + mock_kv_connector_output.finished_recving = False + + # Mock return IntermediateTensors with kv_connector_output + mock_intermediate_output = MagicMock(spec=IntermediateTensors) + mock_intermediate_output.tensors = {"output_tensor": "data"} + mock_intermediate_output.kv_connector_output = mock_kv_connector_output + worker.model_runner.execute_model.return_value = mock_intermediate_output + + mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 + + # Test execute_model + result = worker.execute_model(mock_scheduler_output) + + # Verify tensor reception and sending + mock_pp_group.recv_tensor_dict.assert_called_once() + mock_pp_group.send_tensor_dict.assert_called_once() + + # When both finished_sending and finished_recving are False, should return EMPTY_MODEL_RUNNER_OUTPUT directly + self.assertEqual(result, mock_empty_output) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 6298d34..820ec63 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -28,8 +28,7 @@ from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.lora.request import LoRARequest @@ -223,34 +222,36 @@ class NPUWorker(WorkerBase): scheduler_output: "SchedulerOutput", ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: intermediate_tensors = None - if not get_pp_group().is_first_rank: + forward_pass = scheduler_output.total_num_scheduled_tokens > 0 + if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group())) output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) + if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): + return output + + assert isinstance(output, IntermediateTensors) parallel_config = self.vllm_config.parallel_config - if parallel_config.distributed_executor_backend != "external_launcher" \ - and not get_pp_group().is_last_rank: - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - if not has_kv_transfer_group(): - return None + assert parallel_config.distributed_executor_backend != ( + "external_launcher") and not get_pp_group().is_last_rank - kv_connector_output = output.kv_connector_output - finished_sending = kv_connector_output.finished_sending - finished_recving = kv_connector_output.finished_recving + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) - if not finished_sending and not finished_recving: - return EMPTY_MODEL_RUNNER_OUTPUT + kv_connector_output = output.kv_connector_output + if not kv_connector_output: + return None - new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - new_output.kv_connector_output = kv_connector_output - return new_output - - assert isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)) + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if (not kv_connector_output.finished_sending + and not kv_connector_output.finished_recving): + return EMPTY_MODEL_RUNNER_OUTPUT + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output return output def load_model(self) -> None: