Files
r200_8f_xtrt_llm/examples/vllm_test/test_api_client.py
2025-08-06 15:49:14 +08:00

102 lines
3.4 KiB
Python

"""Example Python client for vllm.entrypoints.api_server"""
import argparse
import json
from typing import Iterable, List
import requests
from xtrt_llm.vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
def clear_line(n: int = 1) -> None:
    """Erase previously printed terminal lines with ANSI escape codes.

    For each of the ``n`` iterations the "cursor up one row" sequence is
    written, immediately followed by the "erase entire row" sequence, and
    stdout is flushed.

    Args:
        n: How many times to emit the up-and-erase pair. Defaults to 1.
    """
    cursor_up = '\033[1A'
    erase_row = '\x1b[2K'
    up_and_erase = cursor_up + erase_row
    for _ in range(n):
        # No trailing newline: the cursor must stay on the row just cleared.
        print(up_and_erase, end='', flush=True)
def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False) -> requests.Response:
    """Send a generation request for ``prompt`` to ``api_url`` via HTTP POST.

    The JSON payload always enables beam search with a temperature of 0.0
    and ``max_tokens`` of 16.

    Args:
        prompt: Text sent under the ``"prompt"`` key.
        api_url: Full URL of the endpoint to POST to.
        n: Value sent under the ``"n"`` key.
        stream: Value sent under the ``"stream"`` key of the payload.

    Returns:
        The ``requests.Response`` object returned by ``requests.post``.
    """
    # Fixed sampling settings used by this test client.
    sampling_options = {
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
    }
    payload = {"prompt": prompt, "n": n, **sampling_options, "stream": stream}
    # NOTE: the HTTP-level ``stream=True`` is unconditional; the ``stream``
    # argument only affects the JSON payload sent to the server.
    return requests.post(api_url,
                         headers={"User-Agent": "Test Client"},
                         json=payload,
                         stream=True)
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    """Yield the ``"text"`` field of every JSON message in a streamed response.

    The response body is read through ``response.iter_lines`` using a NUL
    byte (``b"\\0"``) as the message delimiter. Empty chunks are skipped;
    every other chunk is decoded as UTF-8 and parsed as JSON.

    Args:
        response: Response object whose body is iterated.

    Yields:
        The value stored under ``"text"`` in each parsed message.
    """
    raw_messages = response.iter_lines(chunk_size=8192,
                                       decode_unicode=False,
                                       delimiter=b"\0")
    for raw in raw_messages:
        if not raw:
            # Nothing between two delimiters; ignore it.
            continue
        message = json.loads(raw.decode("utf-8"))
        yield message["text"]
def get_response(response: requests.Response) -> List[str]:
    """Return the ``"text"`` field of a non-streamed JSON response.

    Args:
        response: Response object whose ``content`` holds a JSON document.

    Returns:
        The value stored under ``"text"`` in the parsed document.
    """
    return json.loads(response.content)["text"]
def create_test_prompts(repeats: int = 100) -> List[str]:
    """Create a list of test prompts.

    A fixed sequence of nine prompts (three distinct sentences in a mixed
    order) is repeated ``repeats`` times.

    Args:
        repeats: Number of times the nine-prompt sequence is repeated.
            Defaults to 100, which yields 900 prompts. Zero or a negative
            value yields an empty list.

    Returns:
        A new list containing ``9 * repeats`` prompt strings.
    """
    unit_prompts = [
        "To be or not to be,",
        "A robot may not injure a human being",
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly",
        "A robot may not injure a human being",
        "To be or not to be,",
        "It is only with the heart that one can see rightly",
        "To be or not to be,",
        "It is only with the heart that one can see rightly",
    ]
    # List multiplication replaces the manual ``+=`` loop and always
    # produces a fresh list for the caller.
    return unit_prompts * repeats
if __name__ == "__main__":
    # Script entry point: send one or many prompts to the server's
    # /generate endpoint and print the returned beam candidates.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()

    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    stream = args.stream

    # An empty --prompt (the default) selects the built-in batch of prompts.
    prompt_list = [args.prompt] if args.prompt else create_test_prompts()

    # Iterate the prompts directly; the candidate loops below use their own
    # index name so they no longer shadow the outer loop variable.
    for prompt in prompt_list:
        print(f"Prompt: {prompt!r}\n", flush=True)
        response = post_http_request(prompt, api_url, n, stream)

        if stream:
            num_printed_lines = 0
            for candidates in get_streaming_response(response):
                # Erase the candidates printed for the previous message so
                # the new, longer ones overwrite them in place.
                clear_line(num_printed_lines)
                num_printed_lines = 0
                for beam_idx, line in enumerate(candidates):
                    num_printed_lines += 1
                    print(f"Beam candidate {beam_idx}: {line!r}", flush=True)
        else:
            output = get_response(response)
            for beam_idx, line in enumerate(output):
                print(f"Beam candidate {beam_idx}: {line!r}", flush=True)