feat: frequency, min_new_tokens, presence, and repetition penalties (#973)

This commit is contained in:
Juwan Yoo
2024-08-08 04:21:08 -07:00
committed by GitHub
parent 228cf47547
commit ab7875941b
20 changed files with 1898 additions and 18 deletions

View File

@@ -24,7 +24,7 @@ import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import aiohttp
import numpy as np
@@ -47,6 +47,7 @@ class RequestFuncInput:
prompt_len: int
output_len: int
model: str
extra_request_body: Dict[str, Any]
@dataclass
@@ -84,6 +85,7 @@ async def async_request_trt_llm(
"stream": True,
"min_length": request_func_input.output_len,
"end_id": 1048576,
**request_func_input.extra_request_body,
}
if args.disable_ignore_eos:
del payload["min_length"]
@@ -154,6 +156,7 @@ async def async_request_openai_completions(
"max_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
@@ -542,6 +545,7 @@ async def benchmark(
request_rate: float,
disable_tqdm: bool,
enable_multi: bool,
extra_request_body: Dict[str, Any],
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -556,6 +560,7 @@ async def benchmark(
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
@@ -578,6 +583,7 @@ async def benchmark(
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
extra_request_body=extra_request_body,
)
tasks.append(
asyncio.create_task(
@@ -746,6 +752,10 @@ def fire(args: argparse.Namespace):
random.seed(args.seed)
np.random.seed(args.seed)
extra_request_body = {}
if args.extra_request_body:
extra_request_body = json.loads(args.extra_request_body)
if args.port is None:
args.port = {
"sglang": 30000,
@@ -838,6 +848,7 @@ def fire(args: argparse.Namespace):
request_rate=rate,
disable_tqdm=args.disable_tqdm,
enable_multi=args.multi,
extra_request_body=extra_request_body,
)
)
else:
@@ -851,6 +862,7 @@ def fire(args: argparse.Namespace):
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
enable_multi=args.multi,
extra_request_body=extra_request_body,
)
)
@@ -976,6 +988,13 @@ if __name__ == "__main__":
action="store_true",
help="Disable ignoring EOS.",
)
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',
type=str,
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
set_ulimit()