[Bugfix] Adopt the new changes on disaggregated pd from vllm main branch (#2122)
### What this PR does / why we need it?
We notice that vllm's main branch merged the PR
https://github.com/vllm-project/vllm/pull/21072 and
https://github.com/vllm-project/vllm/pull/21473 to support ray backend
and fix some rebase bugs from a previous change. Those changes make the
disaggregated PD in vLLM Ascend break in some scenarios.
In this PR, we adopt those changes to make sure the
`llmdatadist_c_mgr_connector` works fine on the newest vLLM main branch.
### Does this PR introduce _any_ user-facing change?
No user-facing change.
### How was this patch tested?
Relevant unit tests will be added to verify the functionality of these
changes.
- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -1472,6 +1472,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
finished_sending: Optional[set[str]],
|
||||
finished_receiving: Optional[set[str]],
|
||||
) -> ModelRunnerOutput:
|
||||
assert self.input_batch.num_reqs ==\
|
||||
len(self.input_batch.pooling_params), \
|
||||
@@ -1506,7 +1508,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
)
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_receiving)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
@@ -1542,6 +1545,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
if not broadcast_pp_output:
|
||||
if finished_sending or finished_recving:
|
||||
hidden_states.finished_sending = finished_sending
|
||||
hidden_states.finished_recving = finished_recving
|
||||
return hidden_states
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(
|
||||
@@ -1550,7 +1556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np)
|
||||
num_scheduled_tokens_np,
|
||||
finished_sending, finished_recving)
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
|
||||
#
|
||||
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -27,7 +28,8 @@ from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -35,7 +37,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
@@ -204,9 +206,18 @@ class NPUWorker(WorkerBase):
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
return None
|
||||
if not has_kv_transfer_group():
|
||||
return None
|
||||
|
||||
new_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
if output.finished_sending or output.finished_recving:
|
||||
new_output = copy.copy(new_output)
|
||||
new_output.finished_sending = output.finished_sending
|
||||
new_output.finished_recving = output.finished_recving
|
||||
output = new_output
|
||||
|
||||
assert isinstance(output, ModelRunnerOutput)
|
||||
return output if self.is_driver_worker else None
|
||||
return output
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
|
||||
Reference in New Issue
Block a user