Files
enginex-mlu370-vllm/vllm-v0.6.2/benchmarks/concurrent_executor.py
2026-02-04 17:22:39 +08:00

149 lines
5.4 KiB
Python

import concurrent
# `import concurrent` alone does NOT bind the `futures` submodule; the
# explicit import below is required for concurrent.futures.* usage.
import concurrent.futures
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional

import requests
from tqdm.asyncio import tqdm

from backend_request_func import (RequestFuncInput, RequestFuncOutput,
                                  remove_prefix)
@dataclass
class MluRequestFuncInput(RequestFuncInput):
    """Request input extended with MLU-specific streaming options."""

    # Forwarded as payload "stream_options.include_usage"; asks the server to
    # emit a usage summary in the stream.
    include_usage: bool = False
    # Forwarded as the "ignore_eos" payload field — presumably keeps the
    # server generating past EOS up to max_tokens (confirm against server).
    ignore_eos: bool = False
@dataclass
class MluRequestFuncOutput(RequestFuncOutput):
    """Request output extended with server-reported usage/metric objects."""

    # "usage" object from a stream chunk, when the server includes one.
    usage: dict = field(
        default_factory=dict)
    # "metric" object from a stream chunk, when the server includes one.
    metric: dict = field(
        default_factory=dict)
def sync_request_openai_completions(
    request_func_input: MluRequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> MluRequestFuncOutput:
    """Send one streaming OpenAI-compatible /completions request, blocking
    until the stream ends, and collect latency metrics.

    Records TTFT (time to first token), per-token inter-token latencies,
    total latency, the concatenated generated text, and any "usage" /
    "metric" objects the server includes in stream chunks.

    Args:
        request_func_input: Request parameters (model, prompt, lengths, ...).
        pbar: Optional progress bar; advanced by one on completion,
            whether the request succeeded or failed.

    Returns:
        MluRequestFuncOutput with success=True on a completed stream, or
        success=False and a formatted traceback string in ``.error``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
    assert not request_func_input.use_beam_search
    payload = {
        "model": request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "best_of": request_func_input.best_of,
        "max_tokens": request_func_input.output_len,
        "ignore_eos": request_func_input.ignore_eos,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {"include_usage": request_func_input.include_usage}
    }
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    }

    output = MluRequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
    # BUGFIX: previously `latency` was only assigned on the "[DONE]" chunk.
    # If the server closed the stream without that sentinel, the later
    # `output.latency = latency` raised NameError, which the broad except
    # misreported as a request failure. Track it explicitly and fall back
    # to measuring at loop exit.
    latency = None
    try:
        with requests.post(url=api_url, json=payload, headers=headers,
                           stream=True) as response:
            response.raise_for_status()
            for chunk in response.iter_lines(decode_unicode=False,
                                             delimiter=b"\n"):
                if not chunk:
                    continue
                chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
                if chunk == "[DONE]":
                    latency = time.perf_counter() - st
                    continue
                data = json.loads(chunk)
                # NOTE: Some completion APIs send a final usage-summary
                # chunk without a token, so check a token was generated.
                if ("choices" in data and len(data["choices"]) > 0
                        and data["choices"][0]["text"]):
                    timestamp = time.perf_counter()
                    if ttft == 0.0:
                        # First token: record time to first token (reuse the
                        # timestamp taken above instead of re-reading the
                        # clock, for consistency with the ITL bookkeeping).
                        ttft = timestamp - st
                        output.ttft = ttft
                    else:
                        # Decoding phase: record inter-token latency.
                        output.itl.append(timestamp - most_recent_timestamp)
                    most_recent_timestamp = timestamp
                    generated_text += data["choices"][0]["text"]
                if "usage" in data and data["usage"] is not None:
                    output.usage = data["usage"]
                if "metric" in data and data["metric"] is not None:
                    output.metric = data["metric"]
        output.generated_text = generated_text
        output.success = True
        output.latency = (latency if latency is not None
                          else time.perf_counter() - st)
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
class ConcurrentExecutor:
    """Drives up to ``concurrency_num`` simultaneous completion requests.

    Requests are submitted to a thread pool; whenever a slot frees up, the
    next prompt from ``input_requests`` is dispatched, until every request
    has been sent and its result harvested.
    """

    def __init__(self, concurrency_num, input_requests) -> None:
        """
        Args:
            concurrency_num: Maximum number of in-flight requests.
            input_requests: Sequence of (prompt, prompt_len, output_len)
                tuples; must support ``len()``.
        """
        self.concurrency_num = concurrency_num
        self.concurrency_tasks = []
        self.input_requests_iter = iter(input_requests)
        self.total_requests = len(input_requests)
        self.send_requests = 0
        self.recv_requests = 0
        self.request_input_kwargs = {}
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.concurrency_num)

    def config_pyload(self, **kwargs):
        """Merge shared MluRequestFuncInput kwargs applied to every request.

        NOTE(review): the name looks like a typo for ``config_payload``;
        kept as-is for backward compatibility with existing callers.
        """
        self.request_input_kwargs.update(**kwargs)

    def run(self, pbar):
        """Execute all requests and return their outputs.

        The thread pool is shut down when this returns (previously the
        worker threads were leaked), so the executor is single-use.

        Args:
            pbar: Optional tqdm progress bar forwarded to each request.

        Returns:
            List of MluRequestFuncOutput in completion order (not
            submission order).
        """
        request_results = []
        try:
            while self.recv_requests < self.total_requests:
                if (len(self.concurrency_tasks) < self.concurrency_num
                        and self.send_requests < self.total_requests):
                    # Free slot available: dispatch the next request.
                    prompt, prompt_len, output_len = next(
                        self.input_requests_iter)
                    self.request_input_kwargs['prompt'] = prompt
                    self.request_input_kwargs['prompt_len'] = prompt_len
                    self.request_input_kwargs['output_len'] = output_len
                    request_func_input = MluRequestFuncInput(
                        **self.request_input_kwargs)
                    self.concurrency_tasks.append(
                        self.executor.submit(sync_request_openai_completions,
                                             request_func_input, pbar))
                    self.send_requests += 1
                else:
                    # Pool saturated (or everything sent): block until at
                    # least one task completes, harvest it, free its slot.
                    # Use the module constant rather than the magic string.
                    done, pending = concurrent.futures.wait(
                        self.concurrency_tasks,
                        return_when=concurrent.futures.FIRST_COMPLETED)
                    self.recv_requests += len(done)
                    for task in done:
                        assert task.done()
                        request_results.append(task.result())
                    self.concurrency_tasks = list(pending)
        finally:
            # Release worker threads even if a request raised.
            self.executor.shutdown(wait=True)
        return request_results