first commit
This commit is contained in:
108
vllm/v1/executor/ray_distributed_executor.py
Normal file
108
vllm/v1/executor/ray_distributed_executor.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.executor.ray_distributed_executor import ( # noqa
|
||||
RayDistributedExecutor as RayDistributedExecutorV0)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FutureWrapper(Future):
|
||||
"""A wrapper around Ray output reference to meet the interface
|
||||
of .execute_model(): The top level (core busy loop) expects .result() api
|
||||
to block and return a single output.
|
||||
|
||||
If aggregator is provided, the outputs from all workers are aggregated upon
|
||||
the result() call. If not only the first worker's output is returned.
|
||||
"""
|
||||
|
||||
def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None):
|
||||
super().__init__()
|
||||
self.refs = refs
|
||||
self.aggregator = aggregator
|
||||
|
||||
def result(self, timeout=None):
|
||||
if timeout is not None:
|
||||
raise NotImplementedError("timeout is not supported")
|
||||
|
||||
if self.aggregator is None:
|
||||
return self.refs[0].get()
|
||||
|
||||
outputs = [ref.get() for ref in self.refs]
|
||||
return self.aggregator.aggregate(outputs, output_rank=0)
|
||||
|
||||
|
||||
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
|
||||
"""Ray distributed executor using Ray Compiled Graphs."""
|
||||
|
||||
supports_pp: bool = True
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
super()._init_executor()
|
||||
|
||||
# KV connector setup
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
@property
|
||||
def max_concurrent_batches(self) -> int:
|
||||
"""Ray distributed executor supports pipeline parallelism,
|
||||
meaning that it allows PP size batches to be executed concurrently.
|
||||
"""
|
||||
if self.scheduler_config.async_scheduling:
|
||||
return 2
|
||||
return self.parallel_config.pipeline_parallel_size
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
|
||||
"""Execute the model on the Ray workers.
|
||||
|
||||
Args:
|
||||
scheduler_output: The scheduler output to execute.
|
||||
non_block: If True, the method will return a Future.
|
||||
|
||||
Returns:
|
||||
The model runner output.
|
||||
"""
|
||||
# Build the compiled DAG for the first time.
|
||||
if self.forward_dag is None: # type: ignore
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
refs = self.forward_dag.execute(scheduler_output) # type: ignore
|
||||
|
||||
if not self.has_connector:
|
||||
# Get output only from a single worker (output_rank)
|
||||
# When PP is not used, we block here until the result is available.
|
||||
if not non_block:
|
||||
return refs[0].get()
|
||||
|
||||
# When PP is used, we return a FutureWrapper immediately so that
|
||||
# the scheduler can yield to the next batch.
|
||||
return FutureWrapper(refs)
|
||||
|
||||
# Get output from all workers when connector is present
|
||||
if not non_block:
|
||||
# Block and get results from all workers
|
||||
outputs = [ref.get() for ref in refs]
|
||||
return self.kv_output_aggregator.aggregate(outputs)
|
||||
|
||||
# Return a future that will aggregate outputs from all workers
|
||||
return FutureWrapper(refs, self.kv_output_aggregator)
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest) -> None:
|
||||
self._run_workers("reinitialize_distributed", reconfig_request)
|
||||
if reconfig_request.new_data_parallel_rank == \
|
||||
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
|
||||
self.shutdown()
|
||||
Reference in New Issue
Block a user