149 lines
5.4 KiB
Python
149 lines
5.4 KiB
Python
import concurrent.futures
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional

import requests
from tqdm.asyncio import tqdm

from backend_request_func import (RequestFuncInput, RequestFuncOutput,
                                  remove_prefix)
|
|
|
|
@dataclass
class MluRequestFuncInput(RequestFuncInput):
    """Request input for the MLU completions benchmark client.

    Extends the project's ``RequestFuncInput`` with two extra request
    options that are forwarded verbatim into the request payload built
    by ``sync_request_openai_completions``.
    """

    # Sent as stream_options.include_usage so the server streams a usage
    # (token-accounting) summary.
    include_usage: bool = False
    # Forwarded as the "ignore_eos" payload field: keep generating past EOS.
    ignore_eos: bool = False
|
|
|
|
|
|
@dataclass
class MluRequestFuncOutput(RequestFuncOutput):
    """Request output extended with raw server-side accounting objects.

    Extends the project's ``RequestFuncOutput`` with the last non-null
    "usage" and "metric" JSON objects observed in the streamed response.
    """

    # Last non-null "usage" object seen in the stream (token accounting).
    usage: dict = field(default_factory=dict)
    # Last non-null "metric" object seen in the stream (server metrics).
    metric: dict = field(default_factory=dict)
|
|
|
|
|
|
def sync_request_openai_completions(
    request_func_input: MluRequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> MluRequestFuncOutput:
    """Synchronously send one streaming request to an OpenAI-compatible
    Completions endpoint and collect per-token timing statistics.

    Args:
        request_func_input: Prompt, model name, sampling options and the
            target ``api_url`` for this request.
        pbar: Optional progress bar, advanced by one when the request
            finishes (on success or failure).

    Returns:
        MluRequestFuncOutput carrying the generated text, total latency,
        time-to-first-token (``ttft``), inter-token latencies (``itl``),
        and the last ``usage`` / ``metric`` objects seen in the stream.
        On failure, ``success`` is False and ``error`` holds the formatted
        traceback.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    assert not request_func_input.use_beam_search
    payload = {
        "model": request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "best_of": request_func_input.best_of,
        "max_tokens": request_func_input.output_len,
        "ignore_eos": request_func_input.ignore_eos,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {"include_usage": request_func_input.include_usage},
    }
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    }

    output = MluRequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
    # BUG FIX: previously "latency" was only bound in the "[DONE]" branch;
    # a stream that never sent the sentinel raised NameError, which the
    # broad except below misreported as a request failure.
    latency = None
    try:
        with requests.post(url=api_url, json=payload, headers=headers,
                           stream=True) as response:
            response.raise_for_status()
            for chunk in response.iter_lines(decode_unicode=False,
                                             delimiter=b"\n"):
                if not chunk:
                    continue
                chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
                if chunk == "[DONE]":
                    latency = time.perf_counter() - st
                    continue
                data = json.loads(chunk)

                # NOTE: Some completion API might have a last
                # usage summary response without a token so we
                # want to check a token was generated
                if ("choices" in data and len(data["choices"]) > 0
                        and data["choices"][0]["text"]):
                    timestamp = time.perf_counter()
                    if ttft == 0.0:
                        # First token. BUG FIX: reuse the timestamp taken
                        # above instead of a second perf_counter() read, so
                        # TTFT and ITL share one clock reading.
                        ttft = timestamp - st
                        output.ttft = ttft
                    else:
                        # Decoding phase: inter-token latency.
                        output.itl.append(timestamp - most_recent_timestamp)

                    most_recent_timestamp = timestamp
                    generated_text += data["choices"][0]["text"]

                if "usage" in data and data["usage"] is not None:
                    output.usage = data["usage"]

                if "metric" in data and data["metric"] is not None:
                    output.metric = data["metric"]

        output.generated_text = generated_text
        output.success = True
        # Fall back to total wall time if the server omitted "[DONE]".
        output.latency = (latency if latency is not None
                          else time.perf_counter() - st)
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
|
|
|
|
|
|
class ConcurrentExecutor:
    """Run benchmark requests through a thread pool, keeping at most
    ``concurrency_num`` requests in flight at any time."""

    def __init__(self, concurrency_num, input_requests) -> None:
        """
        Args:
            concurrency_num: Maximum number of concurrently in-flight
                requests (also the thread-pool size).
            input_requests: Sequence of (prompt, prompt_len, output_len)
                tuples; consumed exactly once by :meth:`run`.
        """
        self.concurrency_num = concurrency_num
        self.concurrency_tasks = []  # futures currently in flight
        self.input_requests_iter = iter(input_requests)
        self.total_requests = len(input_requests)
        self.send_requests = 0  # number of requests submitted so far
        self.recv_requests = 0  # number of requests completed so far
        # Template kwargs shared by every request; per-request prompt
        # fields are filled in inside run().
        self.request_input_kwargs = {}

        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.concurrency_num)

    def config_pyload(self, **kwargs):
        # NOTE(review): "pyload" (sic) kept for backward compatibility with
        # existing callers; records payload template kwargs for every
        # subsequent request.
        self.request_input_kwargs.update(**kwargs)

    def run(self, pbar):
        """Submit every input request and block until all have completed.

        Args:
            pbar: Progress bar forwarded to each request function.

        Returns:
            List of per-request outputs, in completion order.
        """
        request_results = []

        while self.recv_requests < self.total_requests:
            if (len(self.concurrency_tasks) < self.concurrency_num
                    and self.send_requests < self.total_requests):
                prompt, prompt_len, output_len = next(
                    self.input_requests_iter)
                self.request_input_kwargs['prompt'] = prompt
                self.request_input_kwargs['prompt_len'] = prompt_len
                self.request_input_kwargs['output_len'] = output_len
                request_func_input = MluRequestFuncInput(
                    **self.request_input_kwargs)

                self.concurrency_tasks.append(
                    self.executor.submit(sync_request_openai_completions,
                                         request_func_input, pbar))
                self.send_requests += 1
            else:
                # BUG FIX: use the documented constant rather than the raw
                # string "FIRST_COMPLETED" (equal only as a CPython
                # implementation detail).
                done, pending = concurrent.futures.wait(
                    self.concurrency_tasks,
                    return_when=concurrent.futures.FIRST_COMPLETED)
                self.recv_requests += len(done)
                for task in done:
                    assert task.done()
                    request_results.append(task.result())
                self.concurrency_tasks = list(pending)

        # BUG FIX: release the worker threads once every request has
        # completed; previously the pool was never shut down.
        self.executor.shutdown(wait=True)
        return request_results