Add model accuracy test - step 1 (#866)
This commit is contained in:
@@ -28,7 +28,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -481,10 +481,10 @@ class Runtime:
|
||||
trust_remote_code=self.server_args.trust_remote_code,
|
||||
)
|
||||
|
||||
async def add_request(
|
||||
async def async_generate(
|
||||
self,
|
||||
prompt: str,
|
||||
sampling_params: Dict,
|
||||
sampling_params: Optional[Dict] = None,
|
||||
):
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
@@ -507,5 +507,26 @@ class Runtime:
|
||||
yield cur
|
||||
pos += len(cur)
|
||||
|
||||
add_request = async_generate
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
sampling_params: Optional[Dict] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
):
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
"sampling_params": sampling_params,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
}
|
||||
response = requests.post(
|
||||
self.url + "/generate",
|
||||
json=json_data,
|
||||
)
|
||||
return json.dumps(response.json())
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user