[Bugfix] Adopt the new changes on disaggregated pd from vllm main branch (#2122)
### What this PR does / why we need it?
We notice that vllm's main branch merged the PR
https://github.com/vllm-project/vllm/pull/21072 and
https://github.com/vllm-project/vllm/pull/21473 to support ray backend
and fix some rebase bug from previous change. Those changes makes the
disaggregate pd in vllm ascend breaks in some scenario.
In this PR, we adopt those changes to make sure the
`llmdatddist_c_mgr_connector` works fine on the newest vllm main branch.
### Does this PR introduce _any_ user-facing change?
No user face change.
### How was this patch tested?
relevant ut will be added to make sure the functionality of those
changes.
- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -1472,6 +1472,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
finished_sending: Optional[set[str]],
|
||||
finished_receiving: Optional[set[str]],
|
||||
) -> ModelRunnerOutput:
|
||||
assert self.input_batch.num_reqs ==\
|
||||
len(self.input_batch.pooling_params), \
|
||||
@@ -1506,7 +1508,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
)
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_receiving)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
@@ -1542,6 +1545,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
if not broadcast_pp_output:
|
||||
if finished_sending or finished_recving:
|
||||
hidden_states.finished_sending = finished_sending
|
||||
hidden_states.finished_recving = finished_recving
|
||||
return hidden_states
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(
|
||||
@@ -1550,7 +1556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np)
|
||||
num_scheduled_tokens_np,
|
||||
finished_sending, finished_recving)
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
|
||||
Reference in New Issue
Block a user