Add model accuracy test - step 1 (#866)

This commit is contained in:
Ying Sheng
2024-08-03 18:20:50 -07:00
committed by GitHub
parent 7dd8a7e6d9
commit 70cc0749ce
4 changed files with 330 additions and 3 deletions

View File

@@ -28,7 +28,7 @@ import sys
import threading
import time
from http import HTTPStatus
from typing import Dict, Optional
from typing import Dict, List, Optional, Union
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -481,10 +481,10 @@ class Runtime:
trust_remote_code=self.server_args.trust_remote_code,
)
async def add_request(
async def async_generate(
self,
prompt: str,
sampling_params: Dict,
sampling_params: Optional[Dict] = None,
):
json_data = {
"text": prompt,
@@ -507,5 +507,26 @@ class Runtime:
yield cur
pos += len(cur)
add_request = async_generate
def generate(
    self,
    prompt: str,
    sampling_params: Optional[Dict] = None,
    return_logprob: Optional[Union[List[bool], bool]] = False,
    top_logprobs_num: Optional[Union[List[int], int]] = None,
):
    """Send a single synchronous request to the ``/generate`` endpoint.

    The arguments are placed, unmodified, into the JSON request body
    under the keys ``text``, ``sampling_params``, ``return_logprob`` and
    ``top_logprobs_num``. The request is POSTed to ``self.url`` with
    ``"/generate"`` appended.

    Returns:
        The server's JSON response, decoded and then re-serialized to a
        string with ``json.dumps``.
    """
    endpoint = self.url + "/generate"
    payload = {
        "text": prompt,
        "sampling_params": sampling_params,
        "return_logprob": return_logprob,
        "top_logprobs_num": top_logprobs_num,
    }
    reply = requests.post(endpoint, json=payload)
    # Round-trip through json so the return value is a str, not a dict.
    return json.dumps(reply.json())
# Finalizer hook: Python invokes this when the instance is about to be
# destroyed; it delegates to self.shutdown().
# NOTE(review): shutdown() is not visible in this view -- presumably it stops
# the launched server; confirm it is safe to call more than once.
def __del__(self):
self.shutdown()