Improve benchmark (#1140)

This commit is contained in:
Lianmin Zheng
2024-08-17 17:43:23 -07:00
committed by GitHub
parent cdc8d60752
commit 57d0bd91ec
8 changed files with 111 additions and 678 deletions

View File

@@ -149,10 +149,12 @@ async def async_request_openai_completions(
"completions"
), "OpenAI Completions API URL must end with 'completions'."
prompt = request_func_input.prompt
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
"prompt": prompt,
"temperature": 0.0,
"best_of": 1,
"max_tokens": request_func_input.output_len,
@@ -220,6 +222,13 @@ async def async_request_openai_completions(
return output
async def async_request_ginfer(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Placeholder request function for the ``ginfer`` backend.

    This stub is registered in ``ASYNC_REQUEST_FUNCS`` under the key
    ``"ginfer"`` but has no implementation in this commit; selecting the
    ``ginfer`` backend therefore always raises.

    Args:
        request_func_input: Parameters (prompt, model, output length, ...)
            for a single benchmark request.
        pbar: Optional progress bar to update when the request completes.

    Raises:
        NotImplementedError: always — the backend is not implemented yet.
    """
    raise NotImplementedError()
def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
import huggingface_hub.constants
@@ -238,6 +247,13 @@ def get_model(pretrained_model_name_or_path: str) -> str:
def get_tokenizer(
pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path.endswith(
".json"
) or pretrained_model_name_or_path.endswith(".model"):
from sglang.srt.hf_transformers_utils import get_tokenizer
return get_tokenizer(pretrained_model_name_or_path)
if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path
):
@@ -252,6 +268,7 @@ ASYNC_REQUEST_FUNCS = {
"vllm": async_request_openai_completions,
"lmdeploy": async_request_openai_completions,
"trt": async_request_trt_llm,
"ginfer": async_request_ginfer,
}
@@ -351,9 +368,9 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
prompt_token_ids = tokenizer.encode(prompt)
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
completion_token_ids = tokenizer.encode(completion)
prompt_len = len(prompt_token_ids)
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
@@ -361,7 +378,9 @@ def sample_sharegpt_requests(
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
if prompt_len > 1024 or (
prompt_len + output_len > 2048 and fixed_output_len is None
):
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
@@ -422,7 +441,7 @@ def sample_random_requests(
for i in range(num_prompts):
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
if prompt_len > input_lens[i]:
@@ -488,7 +507,7 @@ def calculate_metrics(
output_len = outputs[i].output_len
output_lens.append(output_len)
retokenized_output_len = len(
tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
)
retokenized_output_lens.append(retokenized_output_len)
total_input += input_requests[i][1]
@@ -547,7 +566,6 @@ async def benchmark(
input_requests: List[Tuple[str, int, int]],
request_rate: float,
disable_tqdm: bool,
enable_multi: bool,
extra_request_body: Dict[str, Any],
):
if backend in ASYNC_REQUEST_FUNCS:
@@ -756,6 +774,7 @@ def run_benchmark(args_: argparse.Namespace):
global args
args = args_
# Set global environments
set_ulimit()
random.seed(args.seed)
np.random.seed(args.seed)
@@ -764,12 +783,14 @@ def run_benchmark(args_: argparse.Namespace):
if args.extra_request_body:
extra_request_body = json.loads(args.extra_request_body)
# Set url
if args.port is None:
args.port = {
"sglang": 30000,
"lmdeploy": 23333,
"vllm": 8000,
"trt": 8000,
"ginfer": 9988,
}.get(args.backend, 30000)
api_url = (
@@ -792,7 +813,11 @@ def run_benchmark(args_: argparse.Namespace):
if args.model is None:
print("Please provide a model using `--model` when using `trt` backend.")
sys.exit(1)
elif args.backend == "ginfer":
api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
args.model = args.model or "default"
# Get model name
if args.model is None:
try:
response = requests.get(model_url)
@@ -817,6 +842,7 @@ def run_benchmark(args_: argparse.Namespace):
print(f"{args}\n")
# Read dataset
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
@@ -842,7 +868,21 @@ def run_benchmark(args_: argparse.Namespace):
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
if args.multi:
if not args.multi:
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
extra_request_body=extra_request_body,
)
)
else:
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
request_rates = parse_request_rate_range(args.request_rate_range)
for rate in request_rates:
@@ -855,27 +895,11 @@ def run_benchmark(args_: argparse.Namespace):
input_requests=input_requests,
request_rate=rate,
disable_tqdm=args.disable_tqdm,
enable_multi=args.multi,
extra_request_body=extra_request_body,
)
)
else:
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
enable_multi=args.multi,
extra_request_body=extra_request_body,
)
)
# to avoid relying on SGLang's components
def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
@@ -968,7 +992,7 @@ if __name__ == "__main__":
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
)
parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--multi",
action="store_true",

View File

@@ -30,7 +30,17 @@ from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
try:
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
}
except ImportError:
# We want this file to run without vllm dependency
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
from sglang.srt.utils import is_multimodal_model

View File

@@ -113,30 +113,7 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
import grpc
from ginfer import sampler_pb2, sampler_pb2_grpc
sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
if stop is None:
stop_strings = None
else:
stop_strings = [stop]
sample_request = sampler_pb2.SampleTextRequest(
prompt=prompt,
settings=sampler_pb2.SampleSettings(
max_len=max_tokens,
rng_seed=0,
temperature=max(temperature, 1e-7),
nucleus_p=1,
stop_strings=stop_strings,
),
)
stream = sampler.SampleText(sample_request)
response = "".join([x.text for x in stream])
return response
raise NotImplementedError()
def call_generate_guidance(