# enginex-mlu370-vllm/vllm-v0.6.2/benchmarks/concurrent_executor.py
# (snapshot dated 2026-02-04 17:22:39 +08:00)
import concurrent
import concurrent.futures
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional

import requests
from tqdm.asyncio import tqdm

from backend_request_func import (RequestFuncInput, RequestFuncOutput,
                                  remove_prefix)
@dataclass
class MluRequestFuncInput(RequestFuncInput):
    """Request parameters with MLU-benchmark-specific streaming options.

    Extends the shared ``RequestFuncInput`` with two flags that
    ``sync_request_openai_completions`` forwards in the request payload.
    """

    # Forwarded as stream_options.include_usage: ask the server to emit a
    # usage-summary chunk at the end of the stream.
    include_usage: bool = False
    # Forwarded as the "ignore_eos" sampling flag in the payload.
    ignore_eos: bool = False
@dataclass
class MluRequestFuncOutput(RequestFuncOutput):
    """Per-request result extended with server-reported extras.

    Extends the shared ``RequestFuncOutput`` with the optional ``usage``
    and ``metric`` objects that some streaming responses include.
    """

    # Last non-null "usage" object seen in the stream (empty if none sent).
    usage: dict = field(default_factory=dict)
    # Last non-null "metric" object seen in the stream (empty if none sent).
    metric: dict = field(default_factory=dict)
def sync_request_openai_completions(
    request_func_input: MluRequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> MluRequestFuncOutput:
    """Send one streaming OpenAI-style /completions request synchronously.

    Streams the response line by line, recording time-to-first-token (TTFT),
    per-token inter-token latencies (ITL), total latency, the concatenated
    generated text, and any server-reported "usage"/"metric" objects.

    Args:
        request_func_input: Request parameters (model, prompt, sampling
            options, API URL, ...). The URL must end with "completions" or
            "profile".
        pbar: Optional progress bar; advanced by one on completion
            (success or failure).

    Returns:
        A populated MluRequestFuncOutput. On any exception, ``success`` is
        False and ``error`` holds the formatted traceback.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
    assert not request_func_input.use_beam_search
    payload = {
        "model": request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "best_of": request_func_input.best_of,
        "max_tokens": request_func_input.output_len,
        "ignore_eos": request_func_input.ignore_eos,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {"include_usage": request_func_input.include_usage},
    }
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    }
    output = MluRequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len
    generated_text = ""
    ttft = 0.0
    # BUGFIX: initialize latency so a stream that terminates without a
    # "[DONE]" sentinel cannot raise UnboundLocalError (which the broad
    # except below would misreport as a request failure).
    latency = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
    try:
        with requests.post(url=api_url, json=payload, headers=headers,
                           stream=True) as response:
            response.raise_for_status()
            for chunk in response.iter_lines(decode_unicode=False,
                                             delimiter=b"\n"):
                if not chunk:
                    continue
                chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
                if chunk == "[DONE]":
                    latency = time.perf_counter() - st
                    continue
                data = json.loads(chunk)
                # NOTE: Some completion API might have a last
                # usage summary response without a token so we
                # want to check a token was generated
                if ("choices" in data and len(data["choices"]) > 0
                        and data["choices"][0]["text"]):
                    timestamp = time.perf_counter()
                    if ttft == 0.0:
                        # First token: BUGFIX — reuse the timestamp captured
                        # above instead of calling perf_counter() again, so
                        # TTFT and ITL share the same reference point.
                        ttft = timestamp - st
                        output.ttft = ttft
                    else:
                        # Decoding phase: record inter-token latency.
                        output.itl.append(timestamp - most_recent_timestamp)
                    most_recent_timestamp = timestamp
                    generated_text += data["choices"][0]["text"]
                if "usage" in data and data["usage"] is not None:
                    output.usage = data["usage"]
                if "metric" in data and data["metric"] is not None:
                    output.metric = data["metric"]
        if latency == 0.0:
            # Server closed the stream without sending "[DONE]"; fall back
            # to total wall-clock time so latency is still meaningful.
            latency = time.perf_counter() - st
        output.generated_text = generated_text
        output.success = True
        output.latency = latency
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))
    if pbar:
        pbar.update(1)
    return output
class ConcurrentExecutor:
    """Drive benchmark requests through a thread pool with bounded concurrency.

    Keeps at most ``concurrency_num`` requests in flight at once, pulling
    (prompt, prompt_len, output_len) tuples from ``input_requests`` and
    dispatching each via ``sync_request_openai_completions``.
    """

    def __init__(self, concurrency_num, input_requests) -> None:
        """
        Args:
            concurrency_num: Maximum number of simultaneous in-flight requests.
            input_requests: Sized iterable of (prompt, prompt_len, output_len)
                tuples; consumed lazily, one per dispatched request.
        """
        self.concurrency_num = concurrency_num
        self.concurrency_tasks = []  # futures currently in flight
        self.input_requests_iter = iter(input_requests)
        self.total_requests = len(input_requests)
        self.send_requests = 0  # requests dispatched so far
        self.recv_requests = 0  # requests completed so far
        self.request_input_kwargs = {}  # shared kwargs for every request
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.concurrency_num)

    # NOTE: "pyload" (sic) is kept for backward compatibility with callers.
    def config_pyload(self, **kwargs):
        """Set/override kwargs applied to every MluRequestFuncInput built."""
        self.request_input_kwargs.update(**kwargs)

    def run(self, pbar):
        """Run all requests to completion and return their outputs.

        Args:
            pbar: Progress bar forwarded to each request function.

        Returns:
            List of MluRequestFuncOutput, in completion (not submission) order.
        """
        request_results = []
        while self.recv_requests < self.total_requests:
            if (len(self.concurrency_tasks) < self.concurrency_num
                    and self.send_requests < self.total_requests):
                # Room in the window and work left: dispatch one more request.
                prompt, prompt_len, output_len = next(self.input_requests_iter)
                self.request_input_kwargs['prompt'] = prompt
                self.request_input_kwargs['prompt_len'] = prompt_len
                self.request_input_kwargs['output_len'] = output_len
                request_func_input = MluRequestFuncInput(
                    **self.request_input_kwargs)
                self.concurrency_tasks.append(
                    self.executor.submit(sync_request_openai_completions,
                                         request_func_input, pbar))
                self.send_requests += 1
            else:
                # Window full (or input exhausted): block until at least one
                # future finishes, then collect every finished one.
                # BUGFIX: use the FIRST_COMPLETED constant rather than the
                # string literal "FIRST_COMPLETED", which only works by
                # coincidence of the constant's current value.
                done, pending = concurrent.futures.wait(
                    self.concurrency_tasks,
                    return_when=concurrent.futures.FIRST_COMPLETED)
                self.recv_requests += len(done)
                for task in done:
                    assert task.done()
                    request_results.append(task.result())
                self.concurrency_tasks = list(pending)
        return request_results