"""Synchronous benchmarking client for an OpenAI-compatible completions API.

Streams completions over HTTP, records per-request latency metrics (TTFT,
inter-token latency, end-to-end latency), and fans requests out over a
bounded thread pool.
"""

# BUG FIX: `import concurrent` alone does not load the `futures` submodule;
# the explicit form below is required for `concurrent.futures.*` to resolve.
import concurrent.futures
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional

import requests
from tqdm.asyncio import tqdm

from backend_request_func import (RequestFuncInput, RequestFuncOutput,
                                  remove_prefix)


@dataclass
class MluRequestFuncInput(RequestFuncInput):
    # Ask the server to append a final usage-summary chunk to the stream.
    include_usage: bool = False
    # Keep generating past EOS (useful to force a fixed output length).
    ignore_eos: bool = False


@dataclass
class MluRequestFuncOutput(RequestFuncOutput):
    # Token-usage summary reported by the server (empty if not provided).
    usage: dict = field(default_factory=dict)
    # Backend-specific metric payload (empty if not provided).
    metric: dict = field(default_factory=dict)


def sync_request_openai_completions(
    request_func_input: MluRequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> MluRequestFuncOutput:
    """Send one streaming /completions request and measure its timing.

    Args:
        request_func_input: Fully-populated request description (model,
            prompt, token budget, sampling options, target URL).
        pbar: Optional progress bar; advanced by one on completion
            (success or failure).

    Returns:
        MluRequestFuncOutput with ``ttft``, ``itl``, ``latency``,
        ``generated_text``, ``usage`` and ``metric`` filled in on success;
        ``success=False`` and a formatted traceback in ``error`` otherwise.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
    assert not request_func_input.use_beam_search

    payload = {
        "model": request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "best_of": request_func_input.best_of,
        "max_tokens": request_func_input.output_len,
        "ignore_eos": request_func_input.ignore_eos,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {"include_usage": request_func_input.include_usage},
    }
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    }

    output = MluRequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    ttft = 0.0
    # BUG FIX: `latency` was previously unbound when the stream ended
    # without a "[DONE]" sentinel, raising NameError inside the try block.
    latency = None
    st = time.perf_counter()
    most_recent_timestamp = st
    try:
        with requests.post(url=api_url, json=payload, headers=headers,
                           stream=True) as response:
            response.raise_for_status()
            for chunk in response.iter_lines(decode_unicode=False,
                                             delimiter=b"\n"):
                if not chunk:
                    continue
                chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
                if chunk == "[DONE]":
                    latency = time.perf_counter() - st
                    continue
                data = json.loads(chunk)
                # NOTE: Some completion APIs send a trailing usage-summary
                # chunk with no token, so check a token was generated.
                if ("choices" in data and len(data["choices"]) > 0
                        and data["choices"][0]["text"]):
                    timestamp = time.perf_counter()
                    if ttft == 0.0:
                        # First token: record time-to-first-token.
                        ttft = timestamp - st
                        output.ttft = ttft
                    else:
                        # Decoding phase: record inter-token latency.
                        output.itl.append(timestamp - most_recent_timestamp)
                    most_recent_timestamp = timestamp
                    generated_text += data["choices"][0]["text"]
                if "usage" in data and data["usage"] is not None:
                    output.usage = data["usage"]
                if "metric" in data and data["metric"] is not None:
                    output.metric = data["metric"]
        output.generated_text = generated_text
        output.success = True
        # BUG FIX: fall back to stream-end time if "[DONE]" never arrived.
        output.latency = (latency if latency is not None
                          else time.perf_counter() - st)
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


class ConcurrentExecutor:
    """Drive a fixed number of concurrent streaming requests to completion.

    Keeps at most ``concurrency_num`` requests in flight on a thread pool,
    refilling the window as requests finish, until every input request has
    been sent and its result collected.
    """

    def __init__(self, concurrency_num, input_requests) -> None:
        self.concurrency_num = concurrency_num
        self.concurrency_tasks = []
        self.input_requests_iter = iter(input_requests)
        self.total_requests = len(input_requests)
        self.send_requests = 0
        self.recv_requests = 0
        self.request_input_kwargs = {}
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.concurrency_num)

    def config_pyload(self, **kwargs):
        """Store shared MluRequestFuncInput kwargs applied to every request.

        NOTE(review): method name kept for backward compatibility — it is a
        typo of "config_payload" that callers may already depend on.
        """
        self.request_input_kwargs.update(**kwargs)

    def run(self, pbar):
        """Execute all requests and return their MluRequestFuncOutput list.

        Each input request is a ``(prompt, prompt_len, output_len)`` tuple;
        the shared kwargs from :meth:`config_pyload` supply the rest of the
        request fields.
        """
        request_results = []
        try:
            while self.recv_requests < self.total_requests:
                if (len(self.concurrency_tasks) < self.concurrency_num
                        and self.send_requests < self.total_requests):
                    prompt, prompt_len, output_len = next(
                        self.input_requests_iter)
                    self.request_input_kwargs['prompt'] = prompt
                    self.request_input_kwargs['prompt_len'] = prompt_len
                    self.request_input_kwargs['output_len'] = output_len
                    request_func_input = MluRequestFuncInput(
                        **self.request_input_kwargs)
                    self.concurrency_tasks.append(
                        self.executor.submit(sync_request_openai_completions,
                                             request_func_input, pbar))
                    self.send_requests += 1
                else:
                    # Window is full (or all requests sent): harvest at
                    # least one finished task before submitting more.
                    # BUG FIX: use the module constant instead of the raw
                    # string "FIRST_COMPLETED".
                    done, pending = concurrent.futures.wait(
                        self.concurrency_tasks,
                        return_when=concurrent.futures.FIRST_COMPLETED)
                    self.recv_requests += len(done)
                    for task in done:
                        request_results.append(task.result())
                    self.concurrency_tasks = list(pending)
        finally:
            # BUG FIX: release the pool so worker threads do not leak.
            self.executor.shutdown(wait=True)
        return request_results