diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 55bb562..01f01e6 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1472,6 +1472,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
+        finished_sending: Optional[set[str]],
+        finished_receiving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1506,7 +1508,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-        )
+            finished_sending=finished_sending,
+            finished_recving=finished_receiving)
 
     @torch.inference_mode()
     def execute_model(
@@ -1542,6 +1545,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
             if not broadcast_pp_output:
+                if finished_sending or finished_recving:
+                    hidden_states.finished_sending = finished_sending
+                    hidden_states.finished_recving = finished_recving
                 return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
             get_pp_group().send_tensor_dict(
@@ -1550,7 +1556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np)
+                                  num_scheduled_tokens_np,
+                                  finished_sending, finished_recving)
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
         if broadcast_pp_output:
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 15dba8b..4988ef4 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -17,6 +17,7 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
 #
 
+import copy
 from typing import Optional
 
 import torch
@@ -27,7 +28,8 @@ from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                          has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
 from vllm.lora.request import LoRARequest
@@ -35,7 +37,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerBase
 
 from vllm_ascend.ascend_config import init_ascend_config
@@ -204,9 +206,18 @@ class NPUWorker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            return None
+            if not has_kv_transfer_group():
+                return None
+
+            new_output = EMPTY_MODEL_RUNNER_OUTPUT
+            if output.finished_sending or output.finished_recving:
+                new_output = copy.copy(new_output)
+                new_output.finished_sending = output.finished_sending
+                new_output.finished_recving = output.finished_recving
+            output = new_output
+
         assert isinstance(output, ModelRunnerOutput)
-        return output if self.is_driver_worker else None
+        return output
 
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
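For reviewers, the pattern in the final `worker_v1.py` hunk is worth calling out: non-last pipeline-parallel ranks used to return `None` after forwarding their hidden states, but when a KV-transfer group exists they must now report which requests finished sending or receiving KV blocks. The sketch below distills that copy-on-write pass-through; `StubIntermediateTensors`, `StubModelRunnerOutput`, `EMPTY_OUTPUT`, and `passthrough` are simplified hypothetical stand-ins for illustration, not vLLM's real definitions.

```python
# A minimal, self-contained sketch of the pass-through pattern in this patch.
# The dataclasses below are simplified stand-ins for vLLM's
# IntermediateTensors and ModelRunnerOutput, not the real types.
import copy
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StubIntermediateTensors:
    """Stand-in: the runner attaches finished_* sets as plain attributes."""
    tensors: dict = field(default_factory=dict)
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None


@dataclass
class StubModelRunnerOutput:
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None


# Shared sentinel, analogous to EMPTY_MODEL_RUNNER_OUTPUT. It must never be
# mutated in place, which is why the worker copies it before filling fields.
EMPTY_OUTPUT = StubModelRunnerOutput()


def passthrough(output: StubIntermediateTensors,
                has_kv_transfer_group: bool) -> Optional[StubModelRunnerOutput]:
    """What a non-last PP rank returns after sending its hidden states."""
    if not has_kv_transfer_group:
        return None  # pre-patch behavior: mid-pipeline ranks return nothing
    new_output = EMPTY_OUTPUT
    if output.finished_sending or output.finished_recving:
        # Copy-on-write: clone the sentinel only when there is data to attach.
        new_output = copy.copy(new_output)
        new_output.finished_sending = output.finished_sending
        new_output.finished_recving = output.finished_recving
    return new_output


if __name__ == "__main__":
    it = StubIntermediateTensors(finished_sending={"req-1"})
    out = passthrough(it, has_kv_transfer_group=True)
    assert out is not EMPTY_OUTPUT and out.finished_sending == {"req-1"}
    # With nothing finished, the shared sentinel is returned untouched.
    assert passthrough(StubIntermediateTensors(), True) is EMPTY_OUTPUT
```

The `copy.copy` step matters because `EMPTY_MODEL_RUNNER_OUTPUT` is a module-level shared sentinel: mutating it in place would leak one step's finished sets into every subsequent step's output.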