[Bugfix] Adopt the new changes on disaggregated pd from vllm main branch (#2122)

### What this PR does / why we need it?
We noticed that vLLM's main branch merged the PRs
https://github.com/vllm-project/vllm/pull/21072 and
https://github.com/vllm-project/vllm/pull/21473 to support the Ray backend
and fix some rebase bugs from a previous change. Those changes break
disaggregated PD in vllm-ascend in some scenarios.

In this PR, we adopt those changes so that the
`llmdatadist_c_mgr_connector` works correctly on the newest vLLM main branch.

### Does this PR introduce _any_ user-facing change?

No user-facing change.

### How was this patch tested?
Relevant unit tests will be added to verify the functionality of those
changes.

- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
Pleaplusone
2025-08-04 10:08:58 +08:00
committed by GitHub
parent ddaded1537
commit f939381c6f
2 changed files with 24 additions and 6 deletions

View File

@@ -1472,6 +1472,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]],
finished_receiving: Optional[set[str]],
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
@@ -1506,7 +1508,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
)
finished_sending=finished_sending,
finished_recving=finished_receiving)
@torch.inference_mode()
def execute_model(
@@ -1542,6 +1545,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
if finished_sending or finished_recving:
hidden_states.finished_sending = finished_sending
hidden_states.finished_recving = finished_recving
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(
@@ -1550,7 +1556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np)
num_scheduled_tokens_np,
finished_sending, finished_recving)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:

View File

@@ -17,6 +17,7 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#
import copy
from typing import Optional
import torch
@@ -27,7 +28,8 @@ from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
@@ -35,7 +37,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.ascend_config import init_ascend_config
@@ -204,9 +206,18 @@ class NPUWorker(WorkerBase):
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
if not has_kv_transfer_group():
return None
new_output = EMPTY_MODEL_RUNNER_OUTPUT
if output.finished_sending or output.finished_recving:
new_output = copy.copy(new_output)
new_output.finished_sending = output.finished_sending
new_output.finished_recving = output.finished_recving
output = new_output
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None
return output
def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode: