diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 830270ea..4ab4f06d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -21,7 +21,7 @@ import math import time from collections import defaultdict from contextlib import contextmanager, nullcontext -from copy import deepcopy +from copy import copy, deepcopy from dataclasses import dataclass from multiprocessing import Manager from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union @@ -189,7 +189,6 @@ class ExecuteModelState(NamedTuple): hidden_states: torch.Tensor sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None - kv_connector_output: KVConnectorOutput | None attn_metadata: dict[str, Any] positions: torch.Tensor @@ -1450,6 +1449,7 @@ class NPUModelRunner(GPUModelRunner): # For mid-pipeline stages, return the hidden states. if not broadcast_pp_output: hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output if need_dump: assert self.debugger is not None self.debugger.stop() @@ -1496,19 +1496,32 @@ class NPUModelRunner(GPUModelRunner): hidden_states, sample_hidden_states, aux_hidden_states, - kv_connector_output, attn_metadata, positions, ) + self.kv_connector_output = kv_connector_output return None @torch.inference_mode def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. - return None # noqa + if not kv_connector_output: + return None # noqa + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + need_dump = self.dump_enable and self.debugger is not None # Unpack ephemeral state. ( @@ -1517,8 +1530,7 @@ class NPUModelRunner(GPUModelRunner): spec_decode_metadata, hidden_states, sample_hidden_states, - aux_hidden_states, # noqa - kv_connector_output, + aux_hidden_states, attn_metadata, positions, ) = self.execute_model_state