[Bugfix] Adopt the new changes on disaggregated pd from vllm main branch (#2122)

### What this PR does / why we need it?
We noticed that vLLM's main branch merged PRs
https://github.com/vllm-project/vllm/pull/21072 and
https://github.com/vllm-project/vllm/pull/21473 to support the Ray backend
and to fix some rebase bugs from previous changes. Those changes break
disaggregated PD in vLLM Ascend in some scenarios.

In this PR, we adopt those changes to make sure the
`llmdatddist_c_mgr_connector` works correctly on the newest vLLM main branch.
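For context, the upstream change threads the KV-transfer completion sets through the model runner's per-step output so the disaggregated-PD connector can learn which requests have finished sending or receiving their KV cache. A minimal sketch of that plumbing, using simplified stand-in classes (not the real vLLM types):

```python
from dataclasses import dataclass, field
from typing import Optional

# Simplified illustration only: the per-step output now carries the sets of
# request IDs whose KV-cache transfers have finished sending / receiving.
@dataclass
class ModelRunnerOutput:
    req_ids: list = field(default_factory=list)
    finished_sending: Optional[set] = None
    finished_recving: Optional[set] = None

def step(connector_finished: tuple) -> ModelRunnerOutput:
    # The connector reports (finished_sending, finished_recving) for the step;
    # the runner forwards both sets in its output instead of dropping them.
    finished_sending, finished_recving = connector_finished
    return ModelRunnerOutput(
        req_ids=["req-0"],
        finished_sending=finished_sending,
        finished_recving=finished_recving,
    )

out = step(({"req-0"}, set()))
assert out.finished_sending == {"req-0"}
```

The real field and class names in vLLM may differ; the point is that every output path (sampling, pooling, mid-pipeline) must forward these sets, which is what the diff below restores for the pooling path.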

### Does this PR introduce _any_ user-facing change?

No user-facing change.

### How was this patch tested?
Relevant unit tests will be added to verify the functionality of these
changes.

- vLLM version: v0.10.0
- vLLM main: ad57f23f6a

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Authored by Pleaplusone on 2025-08-04 10:08:58 +08:00, committed by GitHub.
Parent: ddaded1537
Commit: f939381c6f
2 changed files with 24 additions and 6 deletions


```diff
@@ -1472,6 +1472,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
+        finished_sending: Optional[set[str]],
+        finished_receiving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1506,7 +1508,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-        )
+            finished_sending=finished_sending,
+            finished_recving=finished_receiving)
 
     @torch.inference_mode()
     def execute_model(
@@ -1542,6 +1545,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
             if not broadcast_pp_output:
+                if finished_sending or finished_recving:
+                    hidden_states.finished_sending = finished_sending
+                    hidden_states.finished_recving = finished_recving
                 return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
             get_pp_group().send_tensor_dict(
@@ -1550,7 +1556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np)
+                                  num_scheduled_tokens_np,
+                                  finished_sending, finished_recving)
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
             if broadcast_pp_output:
```
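The mid-pipeline branch above attaches the finished-transfer sets directly onto the hidden-states object when this rank is not the last pipeline stage, so the information still reaches the scheduler. A hedged sketch of that pattern, with illustrative stand-in names rather than the real vLLM classes:

```python
# Illustration of the mid-pipeline path: non-last pipeline-parallel ranks
# return an intermediate-tensors object instead of a ModelRunnerOutput, so
# the finished sets are tacked onto that object only when non-empty.
class IntermediateTensors:
    def __init__(self, tensors: dict):
        self.tensors = tensors

def forward_mid_stage(hidden_states, finished_sending, finished_recving):
    # Only attach the attributes when there is something to report, matching
    # the `if finished_sending or finished_recving:` guard in the diff.
    if finished_sending or finished_recving:
        hidden_states.finished_sending = finished_sending
        hidden_states.finished_recving = finished_recving
    return hidden_states

it = IntermediateTensors({"hidden": [0.0]})
out = forward_mid_stage(it, {"req-a"}, set())
assert out.finished_sending == {"req-a"}
```

Guarding the attribute assignment keeps the common case (no in-flight KV transfers) free of extra attributes on the tensors object.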