Perf(PP): support PP with async scheduling. (#7136)
### What this PR does / why we need it?
Following up on PR https://github.com/vllm-project/vllm/pull/32618, this
PR adds async scheduling support for PP in vllm-ascend.
---
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -1430,6 +1430,11 @@ class NPUModelRunner(GPUModelRunner):
         if self.execute_model_state is None:
             # Nothing to do (PP non-final rank case), output isn't used.
             # Receive sampled token ids from the last PP rank when using
             # async scheduling + pipeline parallelism so downstream code
             # (e.g., PCP input preparation) can access them.
             if self.use_async_scheduling and get_pp_group().world_size > 1:
                 self._pp_receive_prev_sampled_token_ids_to_input_batch()
             if not kv_connector_output:
                 return None  # noqa
             # In case of PP with kv transfer, we need to pass through the
@@ -1564,6 +1569,14 @@ class NPUModelRunner(GPUModelRunner):
         global_stream().wait_event(self.sampling_done_event)
         self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output)

         # In async scheduling + PP, broadcast sampled token ids from the
         # last PP rank so other PP ranks can receive them without going
         # through the scheduler/engine IPC path.
         if self.use_async_scheduling:
             pp = get_pp_group()
             if pp.world_size > 1 and pp.is_last_rank:
                 self._pp_broadcast_prev_sampled_token_ids(sampler_output.sampled_token_ids)

         if not self.use_async_scheduling:
             return model_runner_output
         return AsyncGPUModelRunnerOutput(
||||
Reference in New Issue
Block a user