forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
78
vllm-v0.6.2/tests/mq_llm_engine/utils.py
Normal file
78
vllm-v0.6.2/tests/mq_llm_engine/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
from typing import Callable, Tuple, Union
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.multiprocessing.engine import MQLLMEngine
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
|
||||
async def generate(
|
||||
client: MQLLMEngineClient,
|
||||
request_id: str,
|
||||
num_tokens: int,
|
||||
return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
|
||||
|
||||
final_output = None
|
||||
count = 0
|
||||
async for out in client.generate(
|
||||
request_id=request_id,
|
||||
prompt="Hello my name is Robert and",
|
||||
sampling_params=SamplingParams(max_tokens=num_tokens,
|
||||
temperature=0)):
|
||||
|
||||
count += 1
|
||||
final_output = out
|
||||
await asyncio.sleep(0.)
|
||||
|
||||
if return_output:
|
||||
return final_output
|
||||
|
||||
# Confirm we generated all the tokens we expected.
|
||||
return count, request_id
|
||||
|
||||
|
||||
def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
|
||||
# Make engine.
|
||||
engine = MQLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
# Run engine.
|
||||
engine.start()
|
||||
|
||||
|
||||
class RemoteMQLLMEngine:
|
||||
|
||||
def __init__(self,
|
||||
engine_args: AsyncEngineArgs,
|
||||
ipc_path: str,
|
||||
run_fn: Callable = run_normal) -> None:
|
||||
|
||||
self.engine_args = engine_args
|
||||
self.ipc_path = ipc_path
|
||||
context = multiprocessing.get_context("spawn")
|
||||
self.proc = context.Process(target=run_fn,
|
||||
args=(engine_args, ipc_path))
|
||||
self.proc.start()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.proc.kill()
|
||||
|
||||
async def make_client(self) -> MQLLMEngineClient:
|
||||
engine_config = self.engine_args.create_engine_config()
|
||||
client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
|
||||
while True:
|
||||
try:
|
||||
await client.setup()
|
||||
break
|
||||
except TimeoutError:
|
||||
assert self.proc.is_alive()
|
||||
return client
|
||||
Reference in New Issue
Block a user