feat: frequency, min_new_tokens, presence, and repetition penalties (#973)
This commit is contained in:
@@ -24,7 +24,7 @@ import warnings
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
@@ -47,6 +47,7 @@ class RequestFuncInput:
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
extra_request_body: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -84,6 +85,7 @@ async def async_request_trt_llm(
|
||||
"stream": True,
|
||||
"min_length": request_func_input.output_len,
|
||||
"end_id": 1048576,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
if args.disable_ignore_eos:
|
||||
del payload["min_length"]
|
||||
@@ -154,6 +156,7 @@ async def async_request_openai_completions(
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": not args.disable_stream,
|
||||
"ignore_eos": not args.disable_ignore_eos,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
|
||||
@@ -542,6 +545,7 @@ async def benchmark(
|
||||
request_rate: float,
|
||||
disable_tqdm: bool,
|
||||
enable_multi: bool,
|
||||
extra_request_body: Dict[str, Any],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@@ -556,6 +560,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
if not test_output.success:
|
||||
@@ -578,6 +583,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
@@ -746,6 +752,10 @@ def fire(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
extra_request_body = {}
|
||||
if args.extra_request_body:
|
||||
extra_request_body = json.loads(args.extra_request_body)
|
||||
|
||||
if args.port is None:
|
||||
args.port = {
|
||||
"sglang": 30000,
|
||||
@@ -838,6 +848,7 @@ def fire(args: argparse.Namespace):
|
||||
request_rate=rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
enable_multi=args.multi,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -851,6 +862,7 @@ def fire(args: argparse.Namespace):
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
enable_multi=args.multi,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -976,6 +988,13 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Disable ignoring EOS.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extra-request-body",
|
||||
metavar='{"key1": "value1", "key2": "value2"}',
|
||||
type=str,
|
||||
help="Append given JSON object to the request payload. You can use this to specify"
|
||||
"additional generate params like sampling params.",
|
||||
)
|
||||
|
||||
set_ulimit()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user