"""Example Python client for vllm.entrypoints.api_server.

Sends prompts to the server's ``/generate`` endpoint and prints the returned
beam candidates. By default each response is printed once it is complete;
with ``--stream`` the candidates are redrawn in place as chunks arrive.

If ``--prompt`` is not given, a built-in list of test prompts
(9 prompts repeated 100 times) is sent one request at a time.
"""

import argparse
import json
from typing import Iterable, List

import requests

# NOTE(review): none of these names are used in this script. The import is
# kept in case something outside this file relies on it; confirm and drop it,
# since it pulls in the engine package for what is otherwise an HTTP client.
from xtrt_llm.vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams


def clear_line(n: int = 1) -> None:
    """Erase the last ``n`` terminal lines using ANSI escape sequences.

    Each iteration moves the cursor up one line and clears that line.

    Args:
        n: Number of lines to erase. ``0`` erases nothing.
    """
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False,
                      max_tokens: int = 16) -> requests.Response:
    """POST a beam-search generation request and return the raw response.

    Args:
        prompt: Text prompt to complete.
        api_url: Full URL of the server's ``/generate`` endpoint.
        n: Number of candidates requested from the server.
        stream: Whether to ask the server for incremental (streamed) output.
        max_tokens: Maximum number of tokens to generate per candidate.

    Returns:
        The ``requests.Response``. Callers should close it (e.g. by using it
        as a context manager) once they are done reading it.
    """
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    # Always open the HTTP response in streaming mode so that
    # get_streaming_response() can consume chunks as they arrive;
    # get_response() still works because it reads the full body.
    response = requests.post(api_url, headers=headers, json=pload, stream=True)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    """Yield the ``"text"`` field of each streamed JSON chunk.

    Chunks are split on NUL bytes (``b"\\0"``), so the server is expected to
    terminate each JSON object with a NUL. Empty chunks are skipped.
    """
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> List[str]:
    """Return the ``"text"`` field of a complete (non-streamed) JSON response."""
    data = json.loads(response.content)
    output = data["text"]
    return output


def create_test_prompts(num_repeats: int = 100) -> List[str]:
    """Create a list of test prompts.

    Args:
        num_repeats: How many times to repeat the built-in prompt list.
            A value <= 0 yields an empty list.

    Returns:
        The built-in prompts concatenated ``num_repeats`` times, in order.
    """
    unit_prompts = [
        "To be or not to be,",
        "A robot may not injure a human being",
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly",
        "A robot may not injure a human being",
        "To be or not to be,",
        "It is only with the heart that one can see rightly",
        "To be or not to be,",
        "It is only with the heart that one can see rightly",
    ]
    return unit_prompts * num_repeats


def _print_candidates(candidates: Iterable[str]) -> int:
    """Print each candidate on its own line and return how many were printed."""
    count = 0
    for idx, line in enumerate(candidates):
        print(f"Beam candidate {idx}: {line!r}", flush=True)
        count += 1
    return count


def main() -> None:
    """Parse command-line arguments and send each prompt to the server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--stream", action="store_true")
    parser.add_argument("--max-tokens", type=int, default=16)
    args = parser.parse_args()

    api_url = f"http://{args.host}:{args.port}/generate"
    # An empty --prompt (the default) means "run the built-in test prompts".
    prompt_list = [args.prompt] if args.prompt else create_test_prompts()

    for prompt in prompt_list:
        print(f"Prompt: {prompt!r}\n", flush=True)
        # `with` releases the streamed connection even if parsing fails.
        with post_http_request(prompt, api_url, args.n, args.stream,
                               max_tokens=args.max_tokens) as response:
            # Report a server-side error clearly instead of failing later
            # with a JSON decode error or a missing "text" key.
            response.raise_for_status()
            if args.stream:
                num_printed_lines = 0
                for candidates in get_streaming_response(response):
                    # Redraw the current candidates over the previous ones.
                    clear_line(num_printed_lines)
                    num_printed_lines = _print_candidates(candidates)
            else:
                _print_candidates(get_response(response))


if __name__ == "__main__":
    main()