[Bugfix] fix pipeline parallelism bug introduced by async-scheduling refactor work (#4973)
### What this PR does / why we need it?
Currently, when using pipeline parallelism together with PD disaggregation,
the model runner returns None on non-last-PP-rank stages in
`sample_tokens`, which causes an assertion error in vLLM's
KVOutputAggregator on [this
line](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/utils.py#L84).
In fact, all PP workers should return a ModelRunnerOutput that
contains a kv_connector_output, so aggregation can be performed in the
EngineCore scheduler process to ensure all KV transfers have finished
before the KV cache is released later.
To fix this issue, this PR follows gpu_model_runner in vLLM by passing
kv_connector_output through `sample_tokens` so that all ranks return a
ModelRunnerOutput; non-last-PP-rank workers return
EMPTY_MODEL_RUNNER_OUTPUT with the kv_connector_output attached.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import math
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from copy import deepcopy
|
from copy import copy, deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
||||||
@@ -189,7 +189,6 @@ class ExecuteModelState(NamedTuple):
|
|||||||
hidden_states: torch.Tensor
|
hidden_states: torch.Tensor
|
||||||
sample_hidden_states: torch.Tensor
|
sample_hidden_states: torch.Tensor
|
||||||
aux_hidden_states: list[torch.Tensor] | None
|
aux_hidden_states: list[torch.Tensor] | None
|
||||||
kv_connector_output: KVConnectorOutput | None
|
|
||||||
attn_metadata: dict[str, Any]
|
attn_metadata: dict[str, Any]
|
||||||
positions: torch.Tensor
|
positions: torch.Tensor
|
||||||
|
|
||||||
@@ -1450,6 +1449,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# For mid-pipeline stages, return the hidden states.
|
# For mid-pipeline stages, return the hidden states.
|
||||||
if not broadcast_pp_output:
|
if not broadcast_pp_output:
|
||||||
hidden_states.kv_connector_output = kv_connector_output
|
hidden_states.kv_connector_output = kv_connector_output
|
||||||
|
self.kv_connector_output = kv_connector_output
|
||||||
if need_dump:
|
if need_dump:
|
||||||
assert self.debugger is not None
|
assert self.debugger is not None
|
||||||
self.debugger.stop()
|
self.debugger.stop()
|
||||||
@@ -1496,19 +1496,32 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
sample_hidden_states,
|
sample_hidden_states,
|
||||||
aux_hidden_states,
|
aux_hidden_states,
|
||||||
kv_connector_output,
|
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
positions,
|
positions,
|
||||||
)
|
)
|
||||||
|
self.kv_connector_output = kv_connector_output
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def sample_tokens(
|
def sample_tokens(
|
||||||
self, grammar_output: "GrammarOutput | None"
|
self, grammar_output: "GrammarOutput | None"
|
||||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
||||||
|
kv_connector_output = self.kv_connector_output
|
||||||
|
self.kv_connector_output = None
|
||||||
|
|
||||||
if self.execute_model_state is None:
|
if self.execute_model_state is None:
|
||||||
# Nothing to do (PP non-final rank case), output isn't used.
|
# Nothing to do (PP non-final rank case), output isn't used.
|
||||||
return None # noqa
|
if not kv_connector_output:
|
||||||
|
return None # noqa
|
||||||
|
# In case of PP with kv transfer, we need to pass through the
|
||||||
|
# kv_connector_output
|
||||||
|
if kv_connector_output.is_empty():
|
||||||
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
|
|
||||||
|
output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||||
|
output.kv_connector_output = kv_connector_output
|
||||||
|
return output
|
||||||
|
|
||||||
need_dump = self.dump_enable and self.debugger is not None
|
need_dump = self.dump_enable and self.debugger is not None
|
||||||
# Unpack ephemeral state.
|
# Unpack ephemeral state.
|
||||||
(
|
(
|
||||||
@@ -1517,8 +1530,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
sample_hidden_states,
|
sample_hidden_states,
|
||||||
aux_hidden_states, # noqa
|
aux_hidden_states,
|
||||||
kv_connector_output,
|
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
positions,
|
positions,
|
||||||
) = self.execute_model_state
|
) = self.execute_model_state
|
||||||
|
|||||||
Reference in New Issue
Block a user