# Listing metadata (from original paste): 20 lines, 773 B, Python
|
|
import asyncio
|
||
|
|
from typing import List
|
||
|
|
|
||
|
|
from vllm.v1.outputs import PoolerOutput, SamplerOutput
|
||
|
|
from vllm.sequence import ExecuteModelRequest
|
||
|
|
|
||
|
|
# class DistributedExecutorBase():
|
||
|
|
# """Abstract superclass of distributed executor implementations."""
|
||
|
|
|
||
|
|
async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
    """Execute one model step across the distributed workers.

    On the first call, launch ``self._start_worker_execution_loop()`` as a
    background asyncio task and store its handle in
    ``self.parallel_worker_tasks``. Later calls see a non-``None`` handle and
    do not start a second loop. The step itself is then delegated to
    ``self._driver_execute_model_async``.

    Args:
        execute_model_req: The request forwarded unchanged to
            ``self._driver_execute_model_async``.

    Returns:
        The awaited result of
        ``self._driver_execute_model_async(execute_model_req)``. Only the
        driver worker returns the sampling results.
    """
    if self.parallel_worker_tasks is None:
        # Start the model execution loop running in the parallel workers.
        # The task handle is kept on ``self`` for two reasons: asyncio only
        # holds a weak reference to tasks, and a non-None value means this
        # branch runs only once.
        self.parallel_worker_tasks = asyncio.create_task(
            self._start_worker_execution_loop())
        # Yield to the event loop so the freshly created task is scheduled
        # before the driver starts executing.
        # NOTE(review): indentation was lost in the pasted source; this yield
        # is assumed to belong to the `if` branch (it directly follows the
        # task creation). Confirm against the original file.
        await asyncio.sleep(0)

    # Yield to the event loop once per call before delegating to the driver.
    # NOTE(review): the purpose of this extra unconditional yield is not
    # evident from this snippet; confirm it is still needed.
    await asyncio.sleep(0)

    # Only the driver worker returns the sampling results.
    return await self._driver_execute_model_async(execute_model_req)
|